In [1]:
from torchvision import transforms
from torch.utils import data
from tqdm import tqdm
import os
from vid_dataset_framed import *
from tvn1 import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
vid_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((200, 200)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [4]:
gpu = torch.device('cuda:0')

In [5]:
torch.cuda.empty_cache()

In [6]:
train_dataset = VideoDataset('/DATA/ichuviliaeva/videos/UCF50/', 'train_index.txt', vid_transforms)

val_dataset = VideoDataset('/DATA/ichuviliaeva/videos/UCF50/', 'test_index.txt', vid_transforms)

In [7]:
train_dataloader = data.DataLoader(train_dataset, batch_size = 4, shuffle = True)
val_dataloader = data.DataLoader(val_dataset, batch_size = 4)

In [8]:
model = TVN1(50).to(gpu)

In [9]:
optimizer = torch.optim.Adam(model.parameters())

In [10]:
criterion = nn.CrossEntropyLoss()

In [11]:
def train(epoch = 1, verbose = 2, model = model, optimazer = optimizer, criterion = criterion, 
          train_dataloader = train_dataloader, val_dataloader = val_dataloader):
    for t in range(epoch):
        loss_list = []
        for x, vid_lens, labels in tqdm(train_dataloader):
            x = torch.stack([x[b, v] for b in range(x.shape[0]) for v in range(x.shape[1])])
            x = x.to(gpu)
            vid_lens = vid_lens.to(gpu)
            labels = labels.to(gpu)
            optimizer.zero_grad()
            res = model((x, vid_lens))
            loss = criterion(res, labels)
            loss_list.append(loss.detach())
            loss.backward()
            optimizer.step()

        print('epoch ', t, ':')
        print('mean loss = ', torch.mean(torch.tensor(loss_list)))
            
        if t % verbose == 0 or t == epoch - 1:
            mi = 0
            acc = 0
            with torch.no_grad():
                loss_val_list = []
                for x, vid_lens, labels in tqdm(val_dataloader):
                    x = torch.stack([x[b, v] for b in range(x.shape[0]) for v in range(x.shape[1])])
                    x = x.to(gpu)
                    vid_lens = vid_lens.to(gpu)
                    labels = labels.to(gpu)
                    predicts = model((x, vid_lens))
                    loss = criterion(predicts, labels)
                    loss_val_list.append(loss.detach())
                    acc += torch.sum(torch.eq(torch.argmax(predicts, dim=-1), labels).to(dtype=torch.float64)).item()
                    mi += x.shape[0]
                print('mean val loss = ', torch.mean(torch.tensor(loss_val_list)))
                print('accuracy = ', acc / (mi + (mi == 0)))
                
                if t % verbose % 2 == 0 or t == epoch - 1:
                    torch.save(model.state_dict(), 'tvn1-epoch-' + str(t) + '-framed.pth')

In [None]:
train(epoch = 40)

100%|███████████████████████████████████████| 1420/1420 [13:20<00:00,  1.77it/s]


epoch  0 :
mean loss =  tensor(3.9117)


100%|█████████████████████████████████████████| 251/251 [01:55<00:00,  2.18it/s]


mean val loss =  tensor(3.9076)
accuracy =  0.0037387836490528413


100%|███████████████████████████████████████| 1420/1420 [15:08<00:00,  1.56it/s]


epoch  1 :
mean loss =  tensor(3.9067)


100%|███████████████████████████████████████| 1420/1420 [14:37<00:00,  1.62it/s]


epoch  2 :
mean loss =  tensor(3.9055)


100%|█████████████████████████████████████████| 251/251 [01:52<00:00,  2.24it/s]


mean val loss =  tensor(3.9034)
accuracy =  0.0037387836490528413


100%|███████████████████████████████████████| 1420/1420 [14:32<00:00,  1.63it/s]


epoch  3 :
mean loss =  tensor(3.9059)


100%|███████████████████████████████████████| 1420/1420 [14:34<00:00,  1.62it/s]


epoch  4 :
mean loss =  tensor(3.9049)


100%|█████████████████████████████████████████| 251/251 [01:51<00:00,  2.25it/s]


mean val loss =  tensor(3.9089)
accuracy =  0.0024925224327018943


100%|███████████████████████████████████████| 1420/1420 [14:22<00:00,  1.65it/s]


epoch  5 :
mean loss =  tensor(3.9059)


100%|███████████████████████████████████████| 1420/1420 [14:32<00:00,  1.63it/s]


epoch  6 :
mean loss =  tensor(3.9050)


100%|█████████████████████████████████████████| 251/251 [01:44<00:00,  2.41it/s]


mean val loss =  tensor(3.9034)
accuracy =  0.0037387836490528413


100%|███████████████████████████████████████| 1420/1420 [14:29<00:00,  1.63it/s]


epoch  7 :
mean loss =  tensor(3.9050)


100%|███████████████████████████████████████| 1420/1420 [14:33<00:00,  1.63it/s]


epoch  8 :
mean loss =  tensor(3.9050)


100%|█████████████████████████████████████████| 251/251 [01:46<00:00,  2.36it/s]


mean val loss =  tensor(3.9029)
accuracy =  0.0037387836490528413


100%|███████████████████████████████████████| 1420/1420 [14:37<00:00,  1.62it/s]


epoch  9 :
mean loss =  tensor(3.9045)


100%|███████████████████████████████████████| 1420/1420 [14:28<00:00,  1.63it/s]


epoch  10 :
mean loss =  tensor(3.9049)


100%|█████████████████████████████████████████| 251/251 [01:44<00:00,  2.39it/s]


mean val loss =  tensor(3.9056)
accuracy =  0.0037387836490528413


100%|███████████████████████████████████████| 1420/1420 [14:31<00:00,  1.63it/s]


epoch  11 :
mean loss =  tensor(3.9048)


100%|███████████████████████████████████████| 1420/1420 [14:28<00:00,  1.63it/s]


epoch  12 :
mean loss =  tensor(3.9051)


100%|█████████████████████████████████████████| 251/251 [01:44<00:00,  2.40it/s]


mean val loss =  tensor(3.9073)
accuracy =  0.0027417746759720836


100%|███████████████████████████████████████| 1420/1420 [14:33<00:00,  1.63it/s]


epoch  13 :
mean loss =  tensor(3.9040)


100%|███████████████████████████████████████| 1420/1420 [14:31<00:00,  1.63it/s]


epoch  14 :
mean loss =  tensor(3.9051)


100%|█████████████████████████████████████████| 251/251 [01:45<00:00,  2.38it/s]


mean val loss =  tensor(3.9027)
accuracy =  0.0037387836490528413


100%|███████████████████████████████████████| 1420/1420 [14:29<00:00,  1.63it/s]


epoch  15 :
mean loss =  tensor(3.9051)


100%|███████████████████████████████████████| 1420/1420 [14:31<00:00,  1.63it/s]


epoch  16 :
mean loss =  tensor(3.9047)


100%|█████████████████████████████████████████| 251/251 [01:45<00:00,  2.38it/s]


mean val loss =  tensor(3.9056)
accuracy =  0.0037387836490528413


100%|███████████████████████████████████████| 1420/1420 [12:54<00:00,  1.83it/s]


epoch  17 :
mean loss =  tensor(3.9055)


100%|███████████████████████████████████████| 1420/1420 [13:01<00:00,  1.82it/s]


epoch  18 :
mean loss =  tensor(3.9054)


100%|█████████████████████████████████████████| 251/251 [01:43<00:00,  2.42it/s]


mean val loss =  tensor(3.9028)
accuracy =  0.0037387836490528413


100%|███████████████████████████████████████| 1420/1420 [13:02<00:00,  1.82it/s]


epoch  19 :
mean loss =  tensor(3.9051)


100%|███████████████████████████████████████| 1420/1420 [12:47<00:00,  1.85it/s]


epoch  20 :
mean loss =  tensor(3.9049)


100%|█████████████████████████████████████████| 251/251 [01:39<00:00,  2.51it/s]


mean val loss =  tensor(3.9050)
accuracy =  0.0037387836490528413


100%|███████████████████████████████████████| 1420/1420 [12:51<00:00,  1.84it/s]


epoch  21 :
mean loss =  tensor(3.9052)


100%|███████████████████████████████████████| 1420/1420 [12:56<00:00,  1.83it/s]


epoch  22 :
mean loss =  tensor(3.9052)


100%|█████████████████████████████████████████| 251/251 [01:41<00:00,  2.48it/s]


mean val loss =  tensor(3.9059)
accuracy =  0.003115653040877368


100%|███████████████████████████████████████| 1420/1420 [12:56<00:00,  1.83it/s]


epoch  23 :
mean loss =  tensor(3.9051)


100%|███████████████████████████████████████| 1420/1420 [16:15<00:00,  1.46it/s]


epoch  24 :
mean loss =  tensor(3.9053)


100%|█████████████████████████████████████████| 251/251 [01:59<00:00,  2.09it/s]


mean val loss =  tensor(3.9018)
accuracy =  0.0037387836490528413


  8%|███▏                                    | 111/1420 [01:04<11:52,  1.84it/s]