In [1]:
import os
import time
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader, ConcatDataset

from Model import RepNet
from Dataset import CountixDataset

In [2]:
def _logged_progress(log_path, epoch):
    """Report how far `epoch` already got according to a loss log.

    Args:
        log_path: CSV file with the columns ``Epoch,Batch,Loss``.
        epoch: Epoch number to look up.

    Returns:
        ``(last_batch, loss_sum)`` where ``last_batch`` is the highest batch
        index logged for `epoch` (``-1`` if the file is missing or holds no
        row for that epoch) and ``loss_sum`` is the sum of the losses logged
        for that epoch.
    """
    if not os.path.exists(log_path):
        return -1, 0.0

    log = pd.read_csv(log_path)
    rows = log[log['Epoch'] == epoch]
    if len(rows) == 0:
        return -1, 0.0
    return int(rows['Batch'].max()), float(rows['Loss'].sum())


def _remaining_loader(dataset, last_batch, batch_size):
    """Build a sequential DataLoader over the batches after `last_batch`.

    Batch ``b`` covers samples ``[b * batch_size, (b + 1) * batch_size)``, so
    the first unprocessed sample is ``(last_batch + 1) * batch_size``.  When
    every full batch is already done, fewer than `batch_size` samples remain
    and ``drop_last=True`` makes the loader yield nothing.
    """
    first_sample = min((last_batch + 1) * batch_size, len(dataset))
    if first_sample > 0:
        dataset = torch.utils.data.Subset(dataset, range(first_sample, len(dataset)))
    return DataLoader(dataset, batch_size=batch_size, num_workers=2, drop_last=True)


def training_loop(n_epochs,
                  optimizer,
                  lr_scheduler,
                  model,
                  train_set,
                  val_set,
                  batch_size,
                  lastCkptPath = None):
    """Train and validate `model`, resumable at batch granularity.

    After every training batch a checkpoint is written to
    ``checkpoint/repnet.pt`` and the batch loss is appended to
    ``trainingLosses.csv``; validation losses go to ``validationLosses.csv``.
    On resume, the batches already logged for the checkpointed epoch are
    skipped and their losses are folded into the running average.

    Args:
        n_epochs: Number of epochs to run, counted from the starting epoch.
        optimizer: Optimizer over ``model.parameters()``.
        lr_scheduler: Scheduler stepped once at the end of every epoch.
        model: Module whose forward pass returns three values; the first two
            are fed to CrossEntropyLoss and BCEWithLogitsLoss respectively.
        train_set: Dataset yielding ``(X, y1, y2)`` used for training.
        val_set: Dataset yielding ``(X, y1, y2)`` used for validation.
        batch_size: Batch size for both loaders; a trailing partial batch is
            dropped (``drop_last=True``).
        lastCkptPath: Checkpoint to resume from.  When ``None`` a fresh run is
            started and both loss logs are deleted.
    """
    train_log = "trainingLosses.csv"
    val_log = "validationLosses.csv"
    ckpt_path = 'checkpoint/repnet.pt'
    log_header = "Epoch,Batch,Loss\n"

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    prevEpoch = 0
    if lastCkptPath is not None:
        # map_location lets a checkpoint written on GPU be resumed on CPU
        # (and the other way round).
        checkpoint = torch.load(lastCkptPath, map_location=device)
        prevEpoch = checkpoint['epoch']

        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # The optimizer state was loaded while the parameters were still on
        # their original device; move it to where the model is about to go.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    else:
        # Fresh run: stale logs would be mistaken for resumable progress.
        for stale_log in (train_log, val_log):
            if os.path.exists(stale_log):
                os.remove(stale_log)

    model.to(device)

    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    loss_func1 = torch.nn.CrossEntropyLoss()
    loss_func2 = torch.nn.BCEWithLogitsLoss()

    for epoch in range(prevEpoch, n_epochs + prevEpoch):

        # ---- training ----
        # Progress is looked up for *this* epoch, so a finished previous epoch
        # never hides work that still has to be done in the current one.
        last_batch, train_loss_sum = _logged_progress(train_log, epoch)
        train_loader = _remaining_loader(train_set, last_batch, batch_size)
        write_header = not os.path.exists(train_log)

        model.train()
        with open(train_log, 'a') as tl:
            if write_header:
                tl.write(log_header)

            pbar = tqdm(train_loader,
                        initial = last_batch + 1,
                        total = last_batch + 1 + len(train_loader))

            for X, y1, y2 in pbar:

                torch.cuda.empty_cache()
                X = X.to(device).float()
                y1 = y1.to(device).long()
                y2 = y2.to(device).float()
                ypred1, ypred2, _ = model(X)

                loss1 = loss_func1(ypred1, y1)             #period length
                loss2 = loss_func2(ypred2, y2)             #periodicity
                loss = loss1 + loss2

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss_sum += loss.item()
                last_batch += 1

                pbar.set_postfix({'Epoch': epoch,
                                  'Training Loss':(train_loss_sum/(last_batch+1))})

                #save checkpoint
                checkpoint = {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(checkpoint, ckpt_path)

                #append training loss
                tl.write(str(epoch) + "," + str(last_batch) + "," + str(loss.item()) + "\n")
                tl.flush()

        # ---- validation ----
        with torch.no_grad():

            last_batch, val_loss_sum = _logged_progress(val_log, epoch)
            val_loader = _remaining_loader(val_set, last_batch, batch_size)
            write_header = not os.path.exists(val_log)

            model.eval()
            with open(val_log, 'a') as vl:
                if write_header:
                    vl.write(log_header)

                pbar = tqdm(val_loader,
                            initial = last_batch + 1,
                            total = last_batch + 1 + len(val_loader))

                for X, y1, y2 in pbar:

                    torch.cuda.empty_cache()
                    X = X.to(device).float()
                    y1 = y1.to(device).long()
                    y2 = y2.to(device).float()

                    ypred1, ypred2, _ = model(X)

                    loss1 = loss_func1(ypred1, y1)             #period length
                    loss2 = loss_func2(ypred2, y2)             #periodicity
                    loss = loss1 + loss2

                    val_loss_sum += loss.item()
                    last_batch += 1

                    pbar.set_postfix({'Epoch':epoch,
                                      'Validation Loss':(val_loss_sum/(last_batch+1))})

                    #append validation loss
                    vl.write(str(epoch) + "," + str(last_batch) + "," + str(loss.item()) + "\n")
                    vl.flush()

        lr_scheduler.step()


def trainTestSplit(dataset, TTR):
    """Split `dataset` sequentially into a leading and a trailing subset.

    Args:
        dataset: Any indexable dataset with a length.
        TTR: Fraction of the samples that goes into the first subset.

    Returns:
        ``(trainDataset, valDataset)``: the first ``int(TTR * len(dataset))``
        samples and the remaining ones, both as ``torch.utils.data.Subset``.
        No shuffling is performed.
    """
    total = len(dataset)
    cut = int(TTR * total)
    leading = torch.utils.data.Subset(dataset, range(0, cut))
    trailing = torch.utils.data.Subset(dataset, range(cut, total))
    return leading, trailing

# Marker printed once the definitions in this cell have been evaluated.
print("done")


done


In [3]:
# Value handed to both CountixDataset and RepNet below; presumably the number
# of frames sampled per video -- confirm against Dataset.py / Model.py.
frame_per_vid = 64
# CSV files for the three Countix splits, relative to the notebook's
# working directory.
trainPath = 'countix/countix_train.csv'
testPath = 'countix/countix_test.csv'
valPath = 'countix/countix_val.csv'

# One dataset object per split.
testDataset = CountixDataset(testPath, frame_per_vid)
trainDataset = CountixDataset(trainPath, frame_per_vid)
valDataset = CountixDataset(valPath, frame_per_vid)

model =  RepNet(frame_per_vid)

# Adam starting at lr=1e-3; ExponentialLR multiplies the learning rate by
# gamma (0.8) on every lr_scheduler.step() call.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.8)

# Marker printed once datasets, model and optimizer have been constructed.
print("done")

done


In [4]:
"""Testing the training loop with sample datasets"""

# Two disjoint 12-sample slices of the training set (indices 0-11 and 12-23)
# for quick smoke tests.
sampleDatasetA = torch.utils.data.Subset(trainDataset, range(0, 12))
sampleDatasetB = torch.utils.data.Subset(trainDataset, range(12, 24))
# One-sample batches, loaded by two worker processes.
# NOTE(review): neither loader is referenced by the later cells shown in this
# notebook -- confirm whether they are still needed.
sampleLoaderA = DataLoader(sampleDatasetA, batch_size=1, num_workers=2)
sampleLoaderB = DataLoader(sampleDatasetB, batch_size=1, num_workers=2)


In [5]:
# Train for 10 epochs with batch size 1.  Validation runs on the dedicated
# validation split so the test split stays untouched for final evaluation.
training_loop(10,
              optimizer,
              lr_scheduler,
              model,
              trainDataset,
              valDataset,
              1
             )

# To resume an interrupted run, pass the last checkpoint as an extra argument:
#              lastCkptPath = 'checkpoint/repnet.pt')

  1%|          | 30/4503 [12:41<31:31:12, 25.37s/it, Epoch=0, Training Loss=10.1]


KeyboardInterrupt: 

In [24]:

# Scratch check: a DataLoader over an empty Subset (range(0, 0)) can be built.
sampleDataset = torch.utils.data.Subset(trainDataset, range(0, 0))
sampleLoader = DataLoader(sampleDataset, batch_size=1, num_workers=2)

# The loader yields no batch, so next() raises StopIteration (see the
# recorded output of this cell).
next(iter(sampleLoader))


StopIteration: 

In [23]:
# Scratch check: range(0, 0) is empty, so the loop body never executes and
# nothing is printed.
for i in range(0,0):
    print(i)

In [5]:
"""verify dataset"""
# Fetch a single sample (index 30); the dataset returns an (X, y1, y2) triple.
# NOTE(review): the recorded run of this cell was interrupted
# (KeyboardInterrupt) before any shape was printed.
X, y1, y2 = trainDataset[30]
print("X shape ", X.shape)
print("y1 shape ", y1.shape)
print("y2 shape ", y2.shape)

start and end time 41.54 44 M7Urc9uBjBU


KeyboardInterrupt: 

In [5]:
"""verify dataloader"""
# Batch four samples together using a single worker process.
trainLoader = DataLoader(trainDataset, batch_size=4, num_workers=1)

# Pull only the first batch and inspect the collated shapes.
X, y1, y2 = next(iter(trainLoader))
print("X shape ", X.shape)
print("y1 shape ", y1.shape)
print("y2 shape ", y2.shape)

X shape  torch.Size([4, 64, 3, 112, 112])
y1 shape  torch.Size([4, 64])
y2 shape  torch.Size([4, 64, 1])


In [6]:
"""verify model"""
# Forward pass on the batch X left in the kernel by the dataloader check.
# The model returns three values; the third one is discarded here.
y1pred, y2pred, _= model(X)

print("y1pred shape ", y1pred.shape)
print("y2pred shape ", y2pred.shape)

y1pred shape  torch.Size([4, 32, 64])
y2pred shape  torch.Size([4, 64, 1])


In [7]:
# Compare tensor types of the first target and the first prediction; the
# recorded output shows a LongTensor target and a FloatTensor prediction.
print(y1.type())
print(y1pred.type())

torch.LongTensor
torch.FloatTensor


In [8]:
# Same loss class that training_loop applies to the model's first output.
loss_func1 = torch.nn.CrossEntropyLoss()

# Per the shapes recorded above, y1pred is (4, 32, 64) and y1 is (4, 64);
# CrossEntropyLoss takes dim 1 of the prediction (size 32) as the class axis.
loss = loss_func1(y1pred, y1)
print(loss.item())

3.5541436672210693


In [None]:
# Scratch check: a slice bound past the end of a list is clamped rather than
# raising, so this prints all four elements.
lis = [1,2,3,4]
print(lis[:6])