In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from torch.utils.data import DataLoader
import torch.nn.functional as F
plt.ion()   # interactive mode
from utils.ImagesDataset import ImagesDataset
from tqdm import tqdm

ENBLE_GPU=True

if ENBLE_GPU:
    cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")


torch.cuda.is_available()

True

In [2]:

class CNN_RNN(nn.Module):
    def __init__(self, classes=17):
        super(CNN_RNN, self).__init__()
        #load trained resnet model
        self.resnet= models.resnet18(pretrained=True)
        self.num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(self.num_ftrs, 17)
        self.resnet.load_state_dict(torch.load('./models/resnet18.pt'))
        
        #freeze weight
        for param in self.resnet.parameters():
            param.requires_grad = False
        
        #remove last class output layer
        self.resnet=torch.nn.Sequential(*(list(self.resnet.children())[:-1]))
        self.lstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=3)
       
    def forward(self, x_3d):
        hidden = None
        toret  = []
        for t in range(x_3d.size(1)):
            with torch.no_grad():
            x = self.resnet(x_3d[:, t, :, :, :])  
            out, hidden = self.lstm(x.unsqueeze(0), hidden)         

            x = self.fc1(out[-1, :, :])
            x = F.relu(x)
            x = self.fc2(x)
            toret.append(x)
        return torch.stack(toret)
    


In [29]:
class CNNLSTM(nn.Module):
    def __init__(self):
        super(CNNLSTM, self).__init__()
        self.resnet= models.resnet18(pretrained=True)
        self.num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(self.num_ftrs, 17)
        self.resnet.load_state_dict(torch.load('./models/resnet18.pt'))
        #freeze weight
        for param in self.resnet.parameters():
            param.requires_grad = False
        
        self.resnet.fc = nn.Sequential(nn.Linear(in_features=512, out_features=256, bias=True))
        
        self.lstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=3)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 17)
       
    def forward(self, x_3d):
        #x3d:  torch.Size([128, 32, 1, 80, 80])
        hidden = None
        toret  = []
        for t in range(x_3d.size(1)):
            with torch.no_grad():
                x = self.resnet(x_3d[:, t, :, :, :])
            
            out, hidden = self.lstm(x.unsqueeze(0), hidden)         
            x = self.fc1(out[-1, :, :])
            x = F.relu(x)
            x = self.fc2(x)
            #print("x shape: ", x.shape)
            
            toret.append(x)
        return torch.stack(toret).permute(1, 0, 2)

In [4]:
from utils.VideosDataset import VideosDataset
dataset = VideosDataset()
dataset.__len__()

70

In [5]:
l=dataset.__len__()
val_split=0.3
train_set, val_set = torch.utils.data.random_split(dataset, [l-int(val_split*l), int(val_split*l)],generator=torch.Generator().manual_seed(42))


In [6]:
bs=2
dataset_sizes = {'train':l-int(val_split*l),'val': int(val_split*l)}
train_loader = DataLoader(dataset=train_set,
                          batch_size=bs,
                          shuffle=True,
                          num_workers=1)
val_loader = DataLoader(dataset=val_set,
                          batch_size=bs,
                          shuffle=True,
                          num_workers=1)
dataloaders={'train':train_loader,'val':val_loader}

In [7]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
   
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            running_number_of_preds=0.0
            # Iterate over data.
            with tqdm(dataloaders[phase], unit="batch") as tepoch:
                batches=0
                for inputs, labels in tepoch:
                    tepoch.set_description(f"Epoch {epoch}")
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()
                    
                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        
                        outputs = model(inputs)
                        
                        outputs = torch.reshape(outputs, (-1,17))
                        labels = labels.view(-1)
                        _, preds = torch.max(outputs, 1)
                        #print('outputs',outputs.shape)
                        #print('labels',labels.shape)
                        loss = criterion(outputs, labels)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                    batches+=1    
                    
                    tot_len=len(preds)
                    running_number_of_preds+=tot_len
                    #print(running_corrects)
                    tepoch.set_postfix(loss=running_loss/(running_number_of_preds), accuracy=100. * running_corrects.item()/(running_number_of_preds))

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / running_number_of_preds
            epoch_acc = running_corrects.double() / running_number_of_preds

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [30]:
model = CNNLSTM()
model.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 2 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=40, gamma=0.1)

dataset = VideosDataset()
first_data = dataset[0]
features, labels = first_data
print(type(features), type(labels))
print(features.shape, labels.shape)

In [31]:
model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=70)

Epoch 0/69
----------


Epoch 0: 100%|████████████████████████████████████| 25/25 [00:03<00:00,  7.10batch/s, accuracy=52.7, loss=0.274]


train Loss: 0.2740 Acc: 0.5265


Epoch 0: 100%|████████████████████████████████████| 11/11 [00:00<00:00, 14.92batch/s, accuracy=83.8, loss=0.264]


val Loss: 0.2644 Acc: 0.8381

Epoch 1/69
----------


Epoch 1: 100%|█████████████████████████████████████| 25/25 [00:03<00:00,  7.15batch/s, accuracy=73.3, loss=0.26]


train Loss: 0.2596 Acc: 0.7327


Epoch 1: 100%|████████████████████████████████████| 11/11 [00:00<00:00, 15.08batch/s, accuracy=83.8, loss=0.246]


val Loss: 0.2464 Acc: 0.8381

Epoch 2/69
----------


Epoch 2: 100%|████████████████████████████████████| 25/25 [00:03<00:00,  7.15batch/s, accuracy=73.3, loss=0.244]


train Loss: 0.2436 Acc: 0.7327


Epoch 2: 100%|████████████████████████████████████| 11/11 [00:00<00:00, 15.81batch/s, accuracy=83.8, loss=0.228]


val Loss: 0.2282 Acc: 0.8381

Epoch 3/69
----------


Epoch 3: 100%|████████████████████████████████████| 25/25 [00:03<00:00,  7.33batch/s, accuracy=73.3, loss=0.228]


train Loss: 0.2276 Acc: 0.7327


Epoch 3:  36%|█████████████▊                        | 4/11 [00:00<00:00,  9.80batch/s, accuracy=100, loss=0.195]


KeyboardInterrupt: 