In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.autograd import Variable
from torchvision import transforms
import videotransforms
import matplotlib as plt
from dataset import Dataset, calculate_accuracy, make_dataset
from timm.models import create_model
from utils import custom_collate_fn




  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')

device = get_device()
print(f"Using device: {device}")

Using device: mps


In [3]:
train_transforms = transforms.Compose([
    videotransforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transforms = transforms.Compose([
    videotransforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [4]:
num_classes = 31
root = '../../Desktop/MLResearch/i3d_smarthome/mp4/'
batch_size = 16
protocol = "CS"

In [5]:
train_dataset = Dataset('./splits/train_cs.txt', 'train', root, "rgb", train_transforms, protocol)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=custom_collate_fn)

val_dataset = Dataset('./splits/validation_cs.txt', 'val', root, "rgb", test_transforms, protocol)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, collate_fn=custom_collate_fn)

dataloaders = {'train': train_dataloader, 'val': val_dataloader}
datasets = {'train': train_dataset, 'val': val_dataset}

Video file not found: ../../Desktop/MLResearch/i3d_smarthome/mp4/Pour.Fromcup_p13_r01_v20_c06.mp4
Video file not found: ../../Desktop/MLResearch/i3d_smarthome/mp4/Pour.Fromcup_p13_r00_v20_c06.mp4
Video file not found: ../../Desktop/MLResearch/i3d_smarthome/mp4/Pour.Fromcup_p19_r00_v08_c01.mp4


In [6]:
# Access the first entry in the training dataset
first_video, first_label = datasets['train'][0]
second_video, second_label = datasets['train'][1]

# Print the number of frames and the shape of the video tensor
print(f"Number of frames in the first video: {first_video.shape[1]}")
print(f"Shape of the video tensor: {first_video.shape}")
print(f"Label of the first video: {first_label}")

print(f"Number of frames in the second video: {second_video.shape[1]}")

# Repeat for the validation dataset
first_val_video, first_val_label = datasets['val'][0]

print(f"Number of frames in the first validation video: {first_val_video.shape[1]}")
print(f"Shape of the validation video tensor: {first_val_video.shape}")
print(f"Label of the first validation video: {first_val_label}")


Number of frames in the first video: 64
Shape of the video tensor: torch.Size([3, 64, 224, 224])
Label of the first video: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       dtype=torch.float64)
Number of frames in the second video: 48
Number of frames in the first validation video: 64
Shape of the validation video tensor: torch.Size([3, 64, 224, 224])
Label of the first validation video: tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       dtype=torch.float64)


In [7]:
# initialize model (swin transformer)
model = create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=num_classes)
model.to(device)
model = nn.DataParallel(model)

In [8]:
# define learning rate and optimizer
init_lr = 0.01
optimizer = optim.AdamW(model.parameters(), lr=init_lr, weight_decay=0.01)
lr_sched = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10, verbose=True)



In [9]:
# Training and validation function
def run_training(max_steps=100, save_model='weights/'):
    steps = 0
    while steps < max_steps:
        print(f'Step {steps}/{max_steps}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train(True)
            else:
                model.train(False)  # Set model to evaluate mode

            tot_loss = 0.0
            tot_cls_loss = 0.0
            tot_acc = 0.0
            num_iter = 0
            optimizer.zero_grad()
            
            for data in dataloaders[phase]:
                num_iter += 1
                inputs, labels = data
                inputs = Variable(inputs.to(device))
                labels = Variable(labels.to(device))

                outputs = model(inputs)
                criterion = nn.CrossEntropyLoss().to(device)
                cls_loss = criterion(outputs, torch.max(labels, dim=1)[1].long())
                tot_cls_loss += cls_loss.data

                loss = cls_loss
                tot_loss += loss.data
                loss.backward()
                acc = calculate_accuracy(outputs, torch.max(labels, dim=1)[1])
                tot_acc += acc
                if phase == 'train':
                    optimizer.step()
                    optimizer.zero_grad()

            if phase == 'train':
                print(f'{phase} Cls Loss: {tot_cls_loss/num_iter:.4f} Tot Loss: {tot_loss/num_iter:.4f}, Acc: {tot_acc/num_iter:.4f}')
                torch.save(model.module.state_dict(), os.path.join(save_model, f'{steps:06d}.pt'))
                tot_loss = tot_cls_loss = tot_acc = 0.0
                steps += 1
            if phase == 'val':
                lr_sched.step(tot_cls_loss/num_iter)
                print(f'{phase} Cls Loss: {tot_cls_loss/num_iter:.4f} Tot Loss: {tot_loss/num_iter:.4f}, Acc: {tot_acc/num_iter:.4f}')

In [10]:
run_training(max_steps=100, save_model='weights/')

Step 0/100
----------


TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.