In [1]:
import sys
import os 
sys.path.insert(1, os.path.split(os.getcwd())[0])

In [2]:
from datasets.dataloader import DataLoader
from model import Trainer
from utils.utils import create_folders
from batch_gen import BatchGenerator

import torch 

from transformer import TransformerClassifier, TransfromerTrainer

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## DataLoader : 

In [4]:
class Args(): 
    def __init__(self, *args, **kwargs):
        self.train_data = 'bslcp'
        self.test_data = 'bslcp'
        self.i3d_training = 'i3d_kinetics_bslcp_981'
        self.num_in_frames = 16
        self.features_dim = 1024
        self.weights = 'opt'
        self.regression = 0 
        self.feature_normalization = 0
        self.eval_use_CP = 0

        self.action = 'train'
        self.seed = 0 
        self.refresh = 'store_true'

        ## Transformer : 
        self.nhead = 4
        self.nhid = 1024
        self.dim_feedforward = 1024
        self.num_layers = 6
        self.dropout = 0

        ## MSTCN : 
        self.num_stages = 4
        self.num_layers = 10 
        self.num_f_maps = 64
        self.features_dim = 1024
        self.bz = 8 
        self.lr = 0.0005
        self.lr_mul = 1
        self.num_epochs = 50
        self.extract_epoch = 10 
        self.classification_threshold = 0.5
        
        ## Optimization
        self.n_warmup_steps = 100
        ## save model : 
        self.use_pseudo_labels = 'store_true'
        self.pretrained = False
        self.folder = ''
        
args = Args()

In [5]:
# load train dataset and test dataset

print(f'Load train data: {args.train_data}')
train_loader = DataLoader(args, args.train_data, 'train')
print(f'Load test data: {args.test_data}')
test_loader = DataLoader(args, args.test_data, 'test')

Load train data: bslcp
Load test data: bslcp


In [6]:
## Some infos : 
print("number of class : ", train_loader.num_classes)
print("cross entropy loss weigths : ", train_loader.weights)
print('number of videos in train : ', len(train_loader.vid_list))
print('number of videos in test : ', len(test_loader.vid_list))

number of class :  2
cross entropy loss weigths :  [0.11247607877029446, 0.8875239212297056]
number of videos in train :  5413
number of videos in test :  702


In [7]:
model_load_dir, model_save_dir, results_save_dir = create_folders(args)
print(model_save_dir)

Saved options to ./exps//models/classification/traindata_bslcp/i3d_kinetics_bslcp_981/supervised/4_10_64_1024_8_0.0005_weighted_opt/seed_0/opt.txt
./exps//models/classification/traindata_bslcp/i3d_kinetics_bslcp_981/supervised/4_10_64_1024_8_0.0005_weighted_opt/seed_0


# Transformer model : 

In [8]:
nhid = 1024  # the dimension of the feedforward network model in nn.TransformerEncoder
dim_feedforward = 1024
nlayers = 4  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 8  # the number of heads in the multiheadattention models
dropout = 0.1  # the dropout value

nclasses = 2 # for classification task only

model = TransformerClassifier(nhead, nhid, dim_feedforward, nlayers, nclasses, dropout).to(device)

## Trainer : 

In [9]:
trainer = TransfromerTrainer(
    args.nhead,
    args.nhid,
    args.dim_feedforward,
    args.num_layers,
    nclasses,
    args.dropout, 
    device,
    train_loader.weights,
    model_save_dir
    )

In [10]:
eval_args = [
    args,
    model_save_dir,
    results_save_dir,
    test_loader.features_dict,
    test_loader.gt_dict,
    test_loader.eval_gt_dict,
    test_loader.vid_list,
    args.num_epochs,
    device,
    'eval',
    args.classification_threshold,
]

In [11]:
batch_gen = BatchGenerator(
        train_loader.num_classes,
        train_loader.gt_dict,
        train_loader.features_dict,
        train_loader.eval_gt_dict
        )

batch_gen.read_data(train_loader.vid_list)

In [None]:
trainer.train(
    model_save_dir,
    batch_gen,
    args.num_epochs,
    args.bz,
    args.lr,
    device,
    eval_args,
    args.lr_mul,
    args.n_warmup_steps,
    pretrained=model_load_dir)

(50/676.625) Batch: 0.3s | Total: 0:00:15 | ETA: 0:03:48 | LR: 0.0015625 | Loss: 0.7196210622787476
(100/676.625) Batch: 0.3s | Total: 0:00:31 | ETA: 0:03:34 | LR: 0.003125 | Loss: 0.7217093110084534
(150/676.625) Batch: 0.3s | Total: 0:00:48 | ETA: 0:03:40 | LR: 0.0025516 | Loss: 0.7106060981750488
(200/676.625) Batch: 0.3s | Total: 0:01:04 | ETA: 0:03:41 | LR: 0.0022097 | Loss: 0.7090731859207153
(250/676.625) Batch: 0.3s | Total: 0:01:21 | ETA: 0:03:29 | LR: 0.0019764 | Loss: 0.7054544687271118
(300/676.625) Batch: 0.3s | Total: 0:01:37 | ETA: 0:03:17 | LR: 0.0018042 | Loss: 0.6983776688575745
(350/676.625) Batch: 0.3s | Total: 0:01:53 | ETA: 0:02:55 | LR: 0.0016704 | Loss: 0.7040632367134094
(400/676.625) Batch: 0.3s | Total: 0:02:09 | ETA: 0:03:17 | LR: 0.0015625 | Loss: 0.7003839015960693
(450/676.625) Batch: 0.3s | Total: 0:02:26 | ETA: 0:03:20 | LR: 0.0014731 | Loss: 0.6980838179588318
(500/676.625) Batch: 0.3s | Total: 0:02:43 | ETA: 0:03:53 | LR: 0.0013975 | Loss: 0.700640559