In [1]:
import numpy as np 
import pandas as pd 
import shutil
import os
import zipfile
import torch
import torch.nn as nn
import cv2
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torch.autograd import Variable
from torchvision import transforms
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import copy
import tqdm
import time
import random
from PIL import Image
import albumentations
from albumentations import pytorch as AT

In [2]:
def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, the ``PYTHONHASHSEED`` environment
    variable, NumPy's legacy global RNG, and PyTorch's CPU and CUDA
    generators. It also puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects subprocesses started after this point; the running
    # interpreter's own hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible CUDA device, not just the current one (multi-GPU).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode selects convolution algorithms by timing them, which can
    # vary between runs even when `deterministic` is set.
    torch.backends.cudnn.benchmark = False

In [5]:
def infer(weights, path_to_dataset, output_path='preds_list.csv'):
    """Classify every image in a folder with a fine-tuned ResNet-101 and save the predictions.

    Args:
        weights: Path to a ``state_dict`` for a ResNet-101 whose ``fc`` layer
            was replaced by a linear head with one output per sport class.
        path_to_dataset: Directory whose entries are the images to classify.
            Entries OpenCV cannot decode (sub-directories, non-image files)
            are skipped with a printed notice.
        output_path: Where to write the CSV with columns ``image`` and
            ``sports``. Defaults to ``'preds_list.csv'`` in the working directory.

    Returns:
        The predictions as a ``pandas.DataFrame`` (same content as the CSV),
        or ``None`` if ``weights`` does not exist (a message is printed then).
    """
    seed_everything(42)

    # Check the checkpoint before building the (large) network.
    if not os.path.exists(weights):
        print('File weights does not exists')
        return None

    # Position in this list == model output index.
    class_names = [
        'baseball', 'formula1', 'fencing', 'motogp', 'ice_hockey',
        'wrestling', 'boxing', 'volleyball', 'cricket', 'basketball',
        'wwe', 'swimming', 'weight_lifting', 'gymnastics', 'tennis',
        'kabaddi', 'badminton', 'football', 'table_tennis', 'hockey',
        'shooting', 'chess',
    ]

    data_transforms = albumentations.Compose([
        albumentations.Resize(256, 256),
        AT.ToTensor()
    ])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # No ImageNet download: a strict load_state_dict below overwrites every
    # parameter and buffer anyway.
    model = torchvision.models.resnet101(pretrained=False)
    model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(weights, map_location=device))
    model = model.to(device)
    model.eval()

    image_names = []
    labels = []

    with torch.no_grad():
        # Sorted so the CSV row order is reproducible (os.listdir order is arbitrary).
        for file_name in sorted(os.listdir(path_to_dataset)):
            image = cv2.imread(os.path.join(path_to_dataset, file_name))
            if image is None:
                # cv2.imread returns None for anything it cannot decode; skip it
                # instead of crashing in cvtColor. Names/labels stay aligned
                # because both lists are appended together below.
                print(f'Skipping unreadable file: {file_name}')
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            transformed_img = data_transforms(image=image)['image'].to(device)
            # Add a leading batch dimension of size 1.
            logits = model(transformed_img.unsqueeze(0).float())
            # Sigmoid/softmax are monotonic, so argmax over raw logits is the same class.
            pred = logits.argmax(dim=1).item()

            image_names.append(file_name)
            labels.append(class_names[pred])

    df = pd.DataFrame({'image': image_names, 'sports': labels})
    df.to_csv(output_path, sep=',', index=False)
    return df

In [6]:
# Classify every image under /content/dataset with the checkpoint at /content/model.pt.
infer(
    weights='/content/model.pt',
    path_to_dataset='/content/dataset',
)