In [1]:
import copy
import glob
import os
import random
import re


import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import accuracy_score
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from tqdm import tqdm

import wandb

def seed_everything(seed=1234):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(2023)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
wandb.init(
    project='video_classification',
    name=f'frames_resnet18_exp4',
    config={
        "epochs": 10,
        "batch_size": 4,
        "num_classes": 15
        #"timestamps": 4,
    },
)

[34m[1mwandb[0m: Currently logged in as: [33mdmitryai[0m ([33mcv-itmo[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Dataset

In [3]:
class VideoDataset(Dataset):
    def __init__(self, path, transform):   
        self.path = path
        self.transform = transform

        if self.path[-1] == '/':
            self.path = self.path[:-1]
        self.frames_path = glob.glob(f'{path}/*/*/*/*/*.jpg')

    def __len__(self):
        return len(self.frames_path)

    def __getitem__(self, idx):
        img_path = self.frames_path[idx]
        label = img_path.split('/')[-5]

        image = Image.open(img_path)
        
        if self.transform:
            image = self.transform(image)

        return image, int(label)

In [4]:
h, w = 224, 224
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

In [5]:
train_transforms = transforms.Compose([
            transforms.Resize((h,w)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomAffine(degrees=0, translate=(0.1,0.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            ])

In [6]:
train_path = 'data/dancing_classes/train/'
val_path = 'data/dancing_classes/val/'

train_ds = VideoDataset(path=train_path, transform=train_transforms)
val_ds = VideoDataset(path=val_path, transform=train_transforms)

In [7]:
val_ds.__getitem__(100)[0].shape

torch.Size([3, 224, 224])

In [8]:
train_dataloader = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=4, shuffle=False)

## Model

In [9]:
num_classes = 15

model = models.resnet18(weights = 'ResNet18_Weights.IMAGENET1K_V1')
in_features = model.fc.in_features 
model.fc = nn.Linear(in_features, num_classes)

In [10]:
x,y = next(iter(train_dataloader))
x.shape, y.shape

(torch.Size([4, 3, 224, 224]), torch.Size([4]))

In [11]:
with torch.no_grad():
    y_pred = model(x)

In [12]:
print(y_pred)
torch.argmax(y_pred,dim=1), y

tensor([[ 0.1306,  0.2246, -0.0965,  0.7036,  0.5584,  0.5220,  0.2671,  0.2731,
          0.0797, -0.4518, -0.4624,  0.1569, -0.0666,  0.0050,  0.4203],
        [-0.0073, -0.2855, -0.1801,  0.3643,  0.3471,  0.0073,  0.3385,  0.1524,
          0.0640,  0.3817,  0.1087, -0.0565,  0.0150,  0.0330,  0.5160],
        [-0.6420,  0.0615,  0.3035,  1.3193,  0.2855,  0.0889,  0.7638,  0.8355,
          0.4275,  0.3666,  0.5812, -0.1144, -0.2759, -0.5773,  0.3460],
        [-0.1438, -0.3745, -0.9264,  0.1535,  0.8385, -0.0587,  0.1829, -0.3209,
          0.0775, -0.3144, -0.7160, -0.1761, -0.2518,  0.1232,  1.1149]])


(tensor([ 3, 14,  3, 14]), tensor([ 8,  1,  3, 13]))

In [13]:
EPOCHS = 5
DEVICE = 'cpu'
EXP_PATH = 'experiments/exp4'
criterion = nn.CrossEntropyLoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min',factor=0.5, patience=5,verbose=1)

In [14]:
def train_model():
    best_model = None
    best_acc = 0
    best_epoch = 0
    
    checkpoint_save_path = f'{EXP_PATH}/checkpoints'
    os.makedirs(checkpoint_save_path,exist_ok=True)
    
    for epoch in range(EPOCHS):

        model.train()
        train_loss = 0
        train_labels, train_preds = [], []

        for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f'Epoch: {epoch}'):
            x_train, y_train = batch
            x_train = x_train.to(DEVICE)
            y_train = y_train.to(DEVICE)

            optimizer.zero_grad()
            
            preds = model(x_train)

            loss = criterion(preds, y_train)

            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()

            preds = preds.argmax(dim=1)
            train_labels.extend(y_train.numpy())
            train_preds.extend(preds.numpy())

        train_loss = train_loss / len(train_dataloader)
        train_acc = accuracy_score(train_labels, train_preds)
        print("Train Loss: {0:.5f}".format(train_loss))
        print("Train Accuracy: {0:.5f}".format(train_acc))

        
        model.eval()
        val_labels, val_preds = [], []
        with torch.no_grad():
            for step, batch in tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc=f'Epoch: {epoch}'):
                x_val, y_val = batch
                x_val = x_val.to(DEVICE)
                y_val = y_val.to(DEVICE)
                
                preds = model(x_val)

                preds = preds.argmax(dim=1)
                val_labels.extend(y_val.numpy())
                val_preds.extend(preds.numpy())

        val_acc = accuracy_score(val_labels, val_preds)
        print("Val Accuracy: {0:.5f}".format(val_acc))

        if val_acc > best_acc:
            best_model = copy.deepcopy(model)
            best_acc = val_acc
            best_epoch = epoch
            torch.save(best_model, f'{checkpoint_save_path}/checkpoint_{best_acc}_{best_epoch}ep.pth')
        
        scheduler.step(loss)

        wandb.log({"Train/Loss" : train_loss}, step=epoch)
        wandb.log({"Train/Accuracy" : train_acc}, step=epoch)
        wandb.log({"Val/Accuracy" : val_acc}, step=epoch)

In [None]:
train_model()

## Testing

In [17]:
model = torch.load('experiments/exp4/checkpoints/checkpoint_0.34114620551665_3ep.pth', map_location='cpu')
model.eval()
print()




In [18]:
path = 'data/dancing_classes/test/'
folders = [os.path.join(path,i) for i in os.listdir(path) if not i.endswith('.csv')]
len(folders)

15

In [19]:
folders[:4]

['data/dancing_classes/test/9',
 'data/dancing_classes/test/0',
 'data/dancing_classes/test/11',
 'data/dancing_classes/test/7']

In [20]:
h, w = 224, 224
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

test_transforms = transforms.Compose([
            transforms.Resize((h,w)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            ])

In [21]:
true_labels = []
predictions = []

for folder in tqdm(folders):
    images = glob.glob(os.path.join(folder,'cut_videos/images/*/*.jpg'))

    preds = []
    for img_path in images:
        img = Image.open((img_path))
        img_tensor = test_transforms(img)
        img_tensor = torch.unsqueeze(img_tensor,dim=0)
        
        with torch.no_grad():
            pred = model(img_tensor)
        pred_label = torch.argmax(pred, dim=1).item()

        preds.append(pred_label)
    true_labels.append(int(folder.split('/')[3]))
    predictions.append(max(set(preds), key=preds.count)) # save most common label among all pred labels

100%|██████████| 15/15 [02:25<00:00,  9.69s/it]


In [22]:
accuracy_score(true_labels, predictions)

0.4