In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms
import videotransforms
from dataset import Dataset, calculate_accuracy, make_dataset
from timm.models import create_model




  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_device():
    """Return the preferred torch device: CUDA if available, else MPS, else CPU."""
    backend = 'cpu'
    if torch.cuda.is_available():
        backend = 'cuda'
    elif torch.backends.mps.is_available():
        backend = 'mps'
    return torch.device(backend)


# Resolve the device once; the model and every batch are moved onto it later.
device = get_device()
print(f"Using device: {device}")

Using device: mps


In [3]:
def _build_transforms():
    """Build the preprocessing pipeline used for both training and evaluation.

    Steps, in order: 224 center crop, conversion to a tensor, per-channel
    normalization (the mean/std values are the usual ImageNet statistics).
    """
    steps = [
        videotransforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


# Training and evaluation currently share the same (augmentation-free) pipeline,
# but each gets its own Compose instance.
train_transforms = _build_transforms()
test_transforms = _build_transforms()

In [4]:
# Experiment configuration shared by the cells below.

# Data location and evaluation protocol (both are handed to Dataset).
root = '../../Desktop/MLResearch/i3d_smarthome/mp4/'
protocol = "CS"

# Model / loader sizes.
num_classes = 31  # passed to create_model as the number of output classes
batch_size = 16   # batch size used by both DataLoaders

In [5]:
# Pinned host memory only speeds up host->CUDA copies. On MPS or CPU PyTorch
# emits a warning and ignores the flag, so enable it only when running on CUDA.
use_pinned_memory = device.type == 'cuda'

# Datasets are built in the same order as before (train first, then val).
train_dataset = Dataset('./splits/train_cs.txt', 'train', root, train_transforms, protocol)
val_dataset = Dataset('./splits/validation_cs.txt', 'val', root, test_transforms, protocol)
datasets = {'train': train_dataset, 'val': val_dataset}

# One loader per split with identical settings; only the training split is shuffled.
dataloaders = {
    split: torch.utils.data.DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=4,
        pin_memory=use_pinned_memory,
    )
    for split, split_dataset in datasets.items()
}

# Keep the individual names available for any cell that refers to them directly.
train_dataloader = dataloaders['train']
val_dataloader = dataloaders['val']

Video file not found: ../../Desktop/MLResearch/i3d_smarthome/mp4/Pour.Fromcup_p13_r01_v20_c06.mp4
Video file not found: ../../Desktop/MLResearch/i3d_smarthome/mp4/Pour.Fromcup_p13_r00_v20_c06.mp4
Video file not found: ../../Desktop/MLResearch/i3d_smarthome/mp4/Pour.Fromcup_p19_r00_v08_c01.mp4


In [6]:
# Initialize the model: a pretrained Swin transformer from timm, created with
# num_classes outputs, moved to the selected device and then wrapped in
# DataParallel (the training loop reads the weights back via `model.module`).
model = nn.DataParallel(
    create_model(
        'swin_base_patch4_window7_224',
        pretrained=True,
        num_classes=num_classes,
    ).to(device)
)

In [7]:
# Learning rate, optimizer and LR schedule.
# NOTE(review): 0.01 is a large step size for AdamW on a pretrained backbone;
# confirm this is intentional.
init_lr = 0.01

optimizer = optim.AdamW(
    model.parameters(),
    lr=init_lr,
    weight_decay=0.01,
)

# Cut the learning rate by 10x once the monitored metric has not improved for
# 10 consecutive scheduler steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; confirm it
# is still accepted by the installed version.
lr_sched = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.1,
    patience=10,
    verbose=True,
)



In [8]:
# Training and validation function
def run_training(max_steps=100, save_model='weights/'):
    """Alternate a training pass and a validation pass for `max_steps` rounds.

    Uses the notebook-level globals `model`, `optimizer`, `lr_sched`,
    `dataloaders`, `device` and `calculate_accuracy`.

    Each round:
      * trains on every batch of ``dataloaders['train']`` (one optimizer step
        per batch), prints the mean loss/accuracy and writes the weights to
        ``<save_model>/<round:06d>.pt``;
      * evaluates every batch of ``dataloaders['val']`` with autograd disabled,
        passes the mean validation loss to ``lr_sched.step`` and prints the
        mean loss/accuracy.

    Labels are reduced to class indices with ``torch.max(labels, dim=1)[1]``
    before being given to the loss and to `calculate_accuracy`.

    Args:
        max_steps: Number of train+validation rounds to run.
        save_model: Directory that receives one checkpoint per round. It is
            created if it does not exist.

    Raises:
        ValueError: If the dataloader of a phase yields no batches.
    """
    # torch.save does not create missing parent directories.
    if save_model:
        os.makedirs(save_model, exist_ok=True)

    # The loss module is identical for every batch, so build it once.
    criterion = nn.CrossEntropyLoss().to(device)

    steps = 0
    while steps < max_steps:
        print(f'Step {steps}/{max_steps}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # False puts the model in evaluation mode

            tot_loss = 0.0
            tot_cls_loss = 0.0
            tot_acc = 0.0
            num_iter = 0
            optimizer.zero_grad()

            # Gradients are only needed while training; disabling autograd for
            # validation avoids building the graph and leaves .grad untouched.
            with torch.set_grad_enabled(is_train):
                for inputs, labels in dataloaders[phase]:
                    num_iter += 1
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    targets = torch.max(labels, dim=1)[1].long()

                    outputs = model(inputs)
                    cls_loss = criterion(outputs, targets)
                    # .item() keeps the running totals as plain Python floats.
                    tot_cls_loss += cls_loss.item()

                    loss = cls_loss
                    tot_loss += loss.item()
                    if is_train:
                        loss.backward()
                    tot_acc += calculate_accuracy(outputs, targets)
                    if is_train:
                        optimizer.step()
                        optimizer.zero_grad()

            if num_iter == 0:
                raise ValueError(f"dataloaders['{phase}'] yielded no batches")

            if is_train:
                print(f'{phase} Cls Loss: {tot_cls_loss/num_iter:.4f} Tot Loss: {tot_loss/num_iter:.4f}, Acc: {tot_acc/num_iter:.4f}')
                # Save the underlying network when the model is wrapped
                # (e.g. in DataParallel), otherwise the model itself.
                net = getattr(model, 'module', model)
                torch.save(net.state_dict(), os.path.join(save_model, f'{steps:06d}.pt'))
                steps += 1
            else:
                lr_sched.step(tot_cls_loss/num_iter)
                print(f'{phase} Cls Loss: {tot_cls_loss/num_iter:.4f} Tot Loss: {tot_loss/num_iter:.4f}, Acc: {tot_acc/num_iter:.4f}')

In [9]:
# Launch training: up to 100 train+validation rounds, one checkpoint per round under weights/.
run_training(max_steps=100, save_model='weights/')

Step 0/100
----------


KeyboardInterrupt: 