In [1]:
import copy
import os
import random
import re

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import accuracy_score
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from tqdm import tqdm

from src.model import Resnt18Rnn, Identity
from src.prediction import predict_label
import wandb

def seed_everything(seed=1234):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(2023)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
wandb.init(
    project='video_classification',
    name=f'ResnetLSTM_exp3',
    config={
        "num_classes": 15,
        "epochs": 10,
        "batch_size": 4,
        "timestamps": 4,
    },
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdmitryai[0m ([33mcv-itmo[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Dataset

In [3]:
class VideoDataset(Dataset):
    def __init__(self, path, timestamps, transform):   
        self.path = path
        self.transform = transform
        self.timestamps = timestamps

        self.frames_path = []
        labels = [i for i in os.listdir(self.path) if not i.endswith('.csv')]
        for label in tqdm(labels):
            label_path = os.path.join(os.path.join(self.path, label), 'cut_videos/images')
            folders = [i for i in os.listdir(label_path) if '.DS' not in i]
            for folder in folders:
                folder_path = os.path.join(label_path, folder)
                folder_frames = [os.path.join(folder_path, i) for i in os.listdir(folder_path)]
                folder_frames.sort(key=self.__natural_keys__)
                self.frames_path.extend(folder_frames)
        print(len(self.frames_path))
        
    def __atoi__(self, text):
        return int(text) if text.isdigit() else text

    def __natural_keys__(self, text):
        return [self.__atoi__(c) for c in re.split(r'(\d+)', text)]

    def __len__(self):
        return int(len(self.frames_path) / self.timestamps)

    def __getitem__(self, idx):
        path2imgs = self.frames_path[idx*self.timestamps:(idx+1)*self.timestamps]
        label = path2imgs[1].split('/')[-5]

        frames = []
        for img_path in path2imgs:
            frame = Image.open(img_path)
            frames.append(frame)
        
        for i, frame in enumerate(frames):
            frame = self.transform(frame)
            frames[i] = frame
        if len(frames)>0:
            frames = torch.stack(frames)
        return frames, int(label)

In [4]:
# transform params
h, w = 224, 224
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

In [5]:
train_transforms = transforms.Compose([
            transforms.Resize((h,w)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomAffine(degrees=0, translate=(0.1,0.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            ])

In [6]:
train_path = '/Users/dmitry/Desktop/cv_itmo/kinetics_video_classification/data/dancing_classes/train/'
val_path = '/Users/dmitry/Desktop/cv_itmo/kinetics_video_classification/data/dancing_classes/val/'

train_ds = VideoDataset(path=train_path, timestamps=4, transform=train_transforms)
val_ds = VideoDataset(path=val_path, timestamps=4, transform=train_transforms)

100%|██████████| 15/15 [00:00<00:00, 49.34it/s]


68985


100%|██████████| 15/15 [00:00<00:00, 475.57it/s]

6997





In [7]:
val_ds.__getitem__(100)[0].shape

torch.Size([4, 3, 224, 224])

In [8]:
def collate_fn_rnn(batch):
    imgs_batch, label_batch = list(zip(*batch))
    imgs_batch = [imgs for imgs in imgs_batch if len(imgs)>0]
    label_batch = [torch.tensor(l) for l, imgs in zip(label_batch, imgs_batch) if len(imgs)>0]
    imgs_tensor = torch.stack(imgs_batch)
    labels_tensor = torch.stack(label_batch)
    return imgs_tensor, labels_tensor

In [9]:
train_dataloader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn_rnn)
val_dataloader = DataLoader(val_ds, batch_size=4, shuffle=False, collate_fn=collate_fn_rnn)

## Model

In [10]:
num_classes = 15
params_model={
    "num_classes": num_classes,
    "dr_rate": 0.1,
    "weights" : 'ResNet18_Weights.IMAGENET1K_V1',
    "rnn_num_layers": 1,
    "rnn_hidden_size": 100,}
model = Resnt18Rnn(params_model)

In [11]:
x,y = next(iter(train_dataloader))
x.shape, y.shape

(torch.Size([4, 4, 3, 224, 224]), torch.Size([4]))

In [12]:
with torch.no_grad():
    y_pred = model(x)

In [13]:
print(y_pred)
torch.argmax(y_pred,dim=1), y

tensor([[ 0.0833,  0.0790,  0.1301,  0.0314, -0.3942,  0.0745,  0.1487,  0.2896,
          0.2166,  0.0855,  0.1833,  0.0677,  0.0642,  0.3404, -0.1077],
        [ 0.1898, -0.1707,  0.2747,  0.0863, -0.4098,  0.2068,  0.1302,  0.2429,
          0.2816,  0.0096,  0.1598,  0.0636, -0.1549,  0.3657, -0.1184],
        [ 0.2681, -0.0837,  0.0923,  0.2646, -0.1522, -0.0604,  0.0724, -0.0726,
          0.3242, -0.0113, -0.0414,  0.0299, -0.1013,  0.5147, -0.0511],
        [ 0.2139, -0.0356,  0.2055,  0.1244, -0.1714, -0.0193,  0.0915,  0.1428,
          0.2525,  0.0847, -0.0321, -0.0231, -0.1442,  0.2625, -0.0798]])


(tensor([13, 13, 13, 13]), tensor([12, 12, 14,  0]))

In [14]:
EPOCHS = 10
DEVICE = 'cpu'
EXP_PATH = 'experiments/exp3'
criterion = nn.CrossEntropyLoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min',factor=0.5, patience=5,verbose=1)

In [15]:
def train_model():
    best_model = None
    best_acc = 0
    best_epoch = 0
    
    checkpoint_save_path = f'{EXP_PATH}/checkpoints'
    os.makedirs(checkpoint_save_path,exist_ok=True)
    
    for epoch in range(EPOCHS):

        model.train()
        train_loss = 0
        train_labels, train_preds = [], []

        for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f'Epoch: {epoch}'):
            x_train, y_train = batch
            x_train = x_train.to(DEVICE)
            y_train = y_train.to(DEVICE)

            optimizer.zero_grad()
            
            preds = model(x_train)

            loss = criterion(preds, y_train)

            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()

            preds = preds.argmax(dim=1)
            train_labels.extend(y_train.numpy())
            train_preds.extend(preds.numpy())

        train_loss = train_loss / len(train_dataloader)
        train_acc = accuracy_score(train_labels, train_preds, )
        print("Train Loss: {0:.5f}".format(train_loss))
        print("Train Accuracy: {0:.5f}".format(train_acc))

        
        model.eval()
        val_labels, val_preds = [], []
        with torch.no_grad():
            for step, batch in tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc=f'Epoch: {epoch}'):
                x_val, y_val = batch
                x_val = x_val.to(DEVICE)
                y_val = y_val.to(DEVICE)
                
                preds = model(x_val)

                preds = preds.argmax(dim=1)
                val_labels.extend(y_val.numpy())
                val_preds.extend(preds.numpy())

        val_acc = accuracy_score(val_labels, val_preds, )
        print("Val Accuracy: {0:.5f}".format(val_acc))

        if val_acc > best_acc:
            best_model = copy.deepcopy(model)
            best_acc = val_acc
            best_epoch = epoch
            torch.save(best_model, f'{checkpoint_save_path}/checkpoint_{best_acc}_{best_epoch}ep.pth')
        
        scheduler.step(loss)

        wandb.log({"Train/Loss" : train_loss}, step=epoch)
        wandb.log({"Train/Accuracy" : train_acc}, step=epoch)
        wandb.log({"Val/Accuracy" : val_acc}, step=epoch)

In [16]:
train_model()

Epoch: 0: 100%|██████████| 4312/4312 [53:02<00:00,  1.35it/s]


Train Loss: 7.98733
Train Accuracy: 0.47286


Epoch: 0: 100%|██████████| 438/438 [02:03<00:00,  3.56it/s]


Val Accuracy: 0.34534


Epoch: 1: 100%|██████████| 4312/4312 [52:26<00:00,  1.37it/s]


Train Loss: 4.09291
Train Accuracy: 0.81097


Epoch: 1: 100%|██████████| 438/438 [02:09<00:00,  3.38it/s]


Val Accuracy: 0.34591


Epoch: 2: 100%|██████████| 4312/4312 [1:01:42<00:00,  1.16it/s]  


Train Loss: 2.17412
Train Accuracy: 0.91528


Epoch: 2: 100%|██████████| 438/438 [01:59<00:00,  3.66it/s]


Val Accuracy: 0.32647


Epoch: 3: 100%|██████████| 4312/4312 [52:48<00:00,  1.36it/s]


Train Loss: 1.28526
Train Accuracy: 0.95112


Epoch: 3: 100%|██████████| 438/438 [02:01<00:00,  3.61it/s]


Val Accuracy: 0.36535


Epoch: 4: 100%|██████████| 4312/4312 [54:06<00:00,  1.33it/s]


Train Loss: 0.82957
Train Accuracy: 0.96712


Epoch: 4: 100%|██████████| 438/438 [02:11<00:00,  3.33it/s]


Val Accuracy: 0.33505


Epoch: 5:  92%|█████████▏| 3956/4312 [49:35<04:25,  1.34it/s]wandb: Network error (ConnectTimeout), entering retry loop.
Epoch: 5: 100%|██████████| 4312/4312 [54:06<00:00,  1.33it/s]


Train Loss: 0.61339
Train Accuracy: 0.97379


Epoch: 5: 100%|██████████| 438/438 [02:10<00:00,  3.36it/s]


Val Accuracy: 0.32419


Epoch: 6:  41%|████      | 1755/4312 [22:04<32:09,  1.33it/s]


KeyboardInterrupt: 

## Testing

In [17]:
model = torch.load('experiments/exp3/checkpoints/checkpoint_0.3653516295025729_3ep.pth', map_location=DEVICE)
model.eval()
print()




In [18]:
# go throw every class and its every folder to get list of folder paths   
path = 'data/dancing_classes/test/'
all_folders = []
for label in [i for i in os.listdir(path) if not i.endswith('.csv')]:
    label_path = os.path.join(os.path.join(path, label), 'cut_videos/images')
    folders = [os.path.join(label_path, i) for i in os.listdir(label_path) if '.DS' not in i]
    all_folders.extend(folders)
print(all_folders[0])

data/dancing_classes/test/9/cut_videos/images/0


In [19]:
test_transforms = transforms.Compose([
            transforms.Resize((h,w)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            ])

In [20]:
timestamps = 4 # take every 4 frames of folder to predict label 
labels, predictions = [], [] # lists to save true labels and predictions

for folder_path in tqdm(all_folders):
    true_label, pred_label = predict_label(model,folder_path, timestamps, test_transforms, inference_mode=False)
    
    labels.append(true_label)
    predictions.append(pred_label)

100%|██████████| 64/64 [02:34<00:00,  2.42s/it]


In [21]:
accuracy_score(labels, predictions)

0.34375