In [1]:
import sys
import os 
# Make the project root (the parent of this notebook's working directory)
# importable, so the local packages imported in the next cell resolve.
PROJECT_ROOT = os.path.dirname(os.getcwd())
sys.path.insert(1, PROJECT_ROOT)

In [2]:
from datasets.dataloader import DataLoader
from model import Trainer
from utils.utils import create_folders
from batch_gen import BatchGenerator

import torch 

from transformer import TransformerClassifier, TransfromerTrainer

In [3]:
# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


## Data loading

In [13]:
class Args():
    """Hard-coded stand-in for the training script's command-line options.

    Every option is a plain attribute. The constructor accepts arbitrary
    positional/keyword arguments but ignores them.
    """

    def __init__(self, *args, **kwargs):
        ## Data / features:
        self.train_data = 'bslcp'
        self.test_data = 'bslcp'
        self.i3d_training = 'i3d_kinetics_bslcp_981'
        self.num_in_frames = 16
        self.features_dim = 1024
        self.weights = 'opt'
        self.regression = 0
        self.feature_normalization = 0
        self.eval_use_CP = 0

        self.action = 'train'
        self.seed = 0
        # NOTE(review): 'store_true' is an argparse *action* name, not a value.
        # It is kept as-is (a truthy string) so downstream behaviour is unchanged.
        self.refresh = 'store_true'

        ## Transformer:
        self.nhead = 4
        self.nhid = 1024
        self.dim_feedforward = 1024
        # Shared by both models. The Transformer trainer reads args.num_layers,
        # and the value also appears in the experiment folder name
        # (".../4_10_64_..."). Previously 6 was assigned here and then silently
        # overwritten by 10 in the MSTCN section; only the surviving value is
        # kept. NOTE(review): add a dedicated attribute if the Transformer
        # really should be 6 layers deep.
        self.num_layers = 10
        self.dropout = 0

        ## MSTCN:
        self.num_stages = 4
        self.num_f_maps = 64
        self.bz = 8  # batch size passed to trainer.train (5413 videos / 8 = 676.625 steps)
        self.lr = 0.0005
        self.num_epochs = 50
        self.extract_epoch = 10
        self.classification_threshold = 0.5

        ## Saving / loading models:
        self.use_pseudo_labels = 'store_true'  # argparse action name used as a truthy value
        self.pretrained = False
        self.folder = ''


args = Args()

In [14]:
# load train dataset and test dataset

# One loader per split. Both read the dataset name from `args`; each is
# announced before it is built, so a slow load shows which split it is on.
print(f'Load train data: {args.train_data}')
train_loader = DataLoader(args, args.train_data, 'train')

print(f'Load test data: {args.test_data}')
test_loader = DataLoader(args, args.test_data, 'test')

Load train data: bslcp
Load test data: bslcp


In [15]:
## Dataset summary: class count, loss class weights, and split sizes.
print("number of classes : ", train_loader.num_classes)
print("cross entropy loss weights : ", train_loader.weights)
print('number of videos in train : ', len(train_loader.vid_list))
print('number of videos in test : ', len(test_loader.vid_list))

number of class :  2
cross entropy loss weigths :  [0.11247607877029446, 0.8875239212297056]
number of videos in train :  5413
number of videos in test :  702


In [16]:
# create_folders also writes the options to opt.txt (see the cell output) and
# returns three directories. Only model_save_dir's value is shown here.
experiment_dirs = create_folders(args)
model_load_dir, model_save_dir, results_save_dir = experiment_dirs
print(model_save_dir)

Saved options to ./exps//models/classification/traindata_bslcp/i3d_kinetics_bslcp_981/supervised/4_10_64_1024_8_0.0005_weighted_opt/seed_0/opt.txt
./exps//models/classification/traindata_bslcp/i3d_kinetics_bslcp_981/supervised/4_10_64_1024_8_0.0005_weighted_opt/seed_0


# Transformer model

In [17]:
# Stand-alone instance of the classifier, built from the same `args` values
# the trainer below receives, so the two configurations cannot drift apart.
# NOTE(review): this cell previously hard-coded nhead=2 / nlayers=4, which
# disagreed with args (nhead=4, num_layers=10). `model` is not passed to the
# trainer in the cells shown here.
nhid = args.nhid  # NOTE(review): exact meaning depends on TransformerClassifier — confirm
dim_feedforward = args.dim_feedforward
nlayers = args.num_layers  # shared with the MSTCN settings (10)
nhead = args.nhead  # number of attention heads
dropout = args.dropout

nclasses = train_loader.num_classes  # 2 for this classification task

model = TransformerClassifier(nhead, nhid, dim_feedforward, nlayers, nclasses, dropout).to(device)

## Trainer

In [18]:
# Build the trainer from the hyper-parameters stored in `args`. The class
# count comes from the data loader instead of a hard-coded literal.
# NOTE(review): args.num_layers is the shared value 10, not a separate
# Transformer depth.
trainer = TransfromerTrainer(
    args.nhead,
    args.nhid,
    args.dim_feedforward,
    args.num_layers,
    train_loader.num_classes,
    args.dropout,
    device,
    train_loader.weights,  # per-class cross-entropy weights (printed above)
    model_save_dir,
)

In [19]:
# Positional arguments forwarded to trainer.train for evaluation on the test
# split. The order is significant, so the list is assembled in three chunks
# that keep it exactly.
eval_args = [args, model_save_dir, results_save_dir]
# test-split data
eval_args += [
    test_loader.features_dict,
    test_loader.gt_dict,
    test_loader.eval_gt_dict,
    test_loader.vid_list,
]
# run settings
eval_args += [args.num_epochs, device, 'eval', args.classification_threshold]

In [20]:
# Batch generator over the training split: class count, ground truth,
# features and evaluation ground truth all come from train_loader.
batch_gen = BatchGenerator(
    train_loader.num_classes, train_loader.gt_dict,
    train_loader.features_dict, train_loader.eval_gt_dict,
)

# Register the training videos with the generator.
batch_gen.read_data(train_loader.vid_list)

In [None]:
# Run training; eval_args is handed to the trainer, presumably for the
# periodic evaluation on the test split.
# NOTE(review): args.pretrained is False, yet model_load_dir is always passed
# as `pretrained=`. Confirm that train() / create_folders treat this as
# "no checkpoint" when not resuming.
# NOTE(review): the progress output shows a fractional step count
# (5413 videos / bz 8 = 676.625). The trainer likely divides without
# rounding up; check TransfromerTrainer.train.
train_positional = (
    model_save_dir,
    batch_gen,
    args.num_epochs,
    args.bz,
    args.lr,
    device,
    eval_args,
)
trainer.train(*train_positional, pretrained=model_load_dir)

(2/676.625) Batch: 0.5s | Total: 0:00:00 | ETA: 0:05:17 | Loss: 0.9071649312973022
(3/676.625) Batch: 0.4s | Total: 0:00:00 | ETA: 0:04:04 | Loss: 3.4232547283172607
(4/676.625) Batch: 0.3s | Total: 0:00:00 | ETA: 0:03:41 | Loss: 2.1804428100585938
(5/676.625) Batch: 0.3s | Total: 0:00:01 | ETA: 0:03:49 | Loss: 2.3132171630859375
(6/676.625) Batch: 0.3s | Total: 0:00:01 | ETA: 0:03:47 | Loss: 0.7140686511993408
(7/676.625) Batch: 0.3s | Total: 0:00:01 | ETA: 0:03:35 | Loss: 0.9903855919837952
(8/676.625) Batch: 0.3s | Total: 0:00:02 | ETA: 0:03:31 | Loss: 2.1184756755828857
(9/676.625) Batch: 0.3s | Total: 0:00:02 | ETA: 0:03:28 | Loss: 0.8072359561920166
(10/676.625) Batch: 0.3s | Total: 0:00:02 | ETA: 0:03:26 | Loss: 0.837419867515564
(11/676.625) Batch: 0.3s | Total: 0:00:02 | ETA: 0:03:23 | Loss: 1.4872561693191528
(12/676.625) Batch: 0.3s | Total: 0:00:03 | ETA: 0:03:10 | Loss: 1.597318172454834
(13/676.625) Batch: 0.3s | Total: 0:00:03 | ETA: 0:03:09 | Loss: 1.1555500030517578
(1