In [None]:
'''using torchvision for dataloading'''
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from psutil import cpu_count
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from dataloader import PicklebotDataset, custom_collate
from mobilenet import MobileNetLarge2D, MobileNetSmall2D, MobileNetSmall3D,MobileNetLarge3D
from movinet import MoViNetA2
from helpers import calculate_accuracy, average_for_plotting

device = 'cuda' if torch.cuda.is_available() else 'cpu'

#hyperparameters
torch.manual_seed(1234)
learning_rate = 3e-4 #we use cosine annealing so this is just a starting point
batch_size = 2 #the paper quotes 128 images/chip, but with video we have to change this
max_iters = 100
eval_interval = 1
weight_decay = 5e-4
std = (0.2104, 0.1986, 0.1829)
mean = (0.3939, 0.3817, 0.3314)
use_autocast = False 
compile = False

#video paths
train_video_paths = '/workspace/picklebotdataset/train'
val_video_paths = '/workspace/picklebotdataset/val'

#annotations paths
train_annotations_file = '/home/henry/Documents/PythonProjects/picklebotdataset/train_labels.csv'
val_annotations_file = '/home/henry/Documents/PythonProjects/picklebotdataset/val_labels.csv'

#video paths
train_video_paths = '/home/henry/Documents/PythonProjects/picklebotdataset/train_all_together'
val_video_paths = '/home/henry/Documents/PythonProjects/picklebotdataset/val_all_together'

#establish our normalization using transforms, 
#note that we are doing this in our dataloader as opposed to in the training loop like with dali
transform = transforms.Normalize(mean,std)

#dataset     
train_dataset = PicklebotDataset(train_annotations_file,train_video_paths,transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True,collate_fn=custom_collate,num_workers=cpu_count())
val_dataset = PicklebotDataset(val_annotations_file,val_video_paths,transform=transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size,shuffle=True,collate_fn=custom_collate,num_workers=cpu_count())

#define model, initialize weights 
model = MobileNetLarge3D()
# model.initialize_weights()
model = model.to(device)

#for multi-gpu
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

# optimizer
# optimizer = optim.RMSprop(params=model.parameters(),lr=learning_rate,weight_decay=weight_decay,momentum=momentum,eps=eps) #starting with AdamW for now. 
optimizer = optim.AdamW(params=model.parameters(),lr=learning_rate, weight_decay=weight_decay)

#cosine annealing
scheduler = CosineAnnealingLR(optimizer, T_max=100)

#loss
criterion = nn.CrossEntropyLoss()
if use_autocast:
    scaler = GradScaler()
    
model_name = model.__class__.__name__
writer = SummaryWriter(f'runs/{model_name}') #tensorboard writer 
# checkpoint = torch.load('checkpoints/MobileNetLarge3D17.pth')
# loaded_state_dict_keys = checkpoint.keys()
# updated_state_dict = {}
# for key,value in checkpoint.items():
#     new_key = key.replace('_orig_mod.','') #remove the prefix
#     updated_state_dict[new_key] = value
# model.load_state_dict(updated_state_dict)


if compile:
    print("compiling the model... (takes a ~minute)")
    unoptimized_model = model
    model = torch.compile(model)  # requires PyTorch 2 and a modern gpu, these lines are from karpathy
    print("compilation complete!")


#estimate loss using the val set, and calculate accuracy
@torch.no_grad()
def estimate_loss():
    #evaluate the model
    model.eval()
    val_losses = [] 
    val_correct = 0
    val_samples = 0

    #calculate the loss
    for val_features,val_labels in tqdm(val_loader):
        val_features = val_features.to(device)
        val_labels = val_labels.long() #waiting to move to device until after forward pass, idk if this matters
        # val_labels = val_labels.expand(val_features.shape[2]) #this is only for our lstm T -> batch size, a lame hack    
        val_outputs = model(val_features)

        val_loss = criterion(val_outputs,val_labels.to(device))
        val_losses.append(val_loss.item())
        
        val_correct += calculate_accuracy(val_outputs,val_labels)
        val_samples += len(val_labels)

    avg_val_loss = np.mean(val_losses)
    val_accuracy = val_correct / val_samples
    return avg_val_loss, val_accuracy

#try except block so we can manually early stop while saving the model
#training loop
start_time = time.time()
train_losses = torch.tensor([])
train_percent = torch.tensor([])
val_losses = []
val_percent = []
counter = 0

try:
    for iter in range(max_iters):
        
        model.train()
        train_correct = 0
        train_samples = 0
        batch_loss_list = []
        batch_percent_list = []

        #forward pass
        for batch_idx, (features,labels) in tqdm(enumerate(train_loader)):
            labels = labels.to(torch.int64)
            features = features.to(device)
            # labels = labels.expand(features.shape[2]) #this is only for our lstm T -> batch size, a lame hack
            
            #zero the gradients
            optimizer.zero_grad(set_to_none=True)
            
            if use_autocast:    
                with autocast():
                    outputs = model(features)
                    loss = criterion(outputs,labels.to(device))
                
                #backprop & update weights

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            else:
                outputs = model(features)
                loss = criterion(outputs,labels.to(device))

                #backprop & update weights
                loss.backward()
                optimizer.step()

            

            batch_loss_list.append(loss.item()) #append the loss of the batch to our list to be averaged and plotted later, this is dataset size / batch size long
            batch_correct = calculate_accuracy(outputs,labels) #number of correct predictions in the batch
            train_correct += batch_correct #this is the total number of correct predictions so far
            train_samples += len(labels) #this is the total number of samples so far
            batch_percent_list.append(train_correct/train_samples)
            writer.add_scalar('training loss', batch_loss_list[-1], counter)
            writer.add_scalar('training accuracy', batch_percent_list[-1], counter)
            counter += 1

        scheduler.step()
        train_losses = torch.cat((train_losses,average_for_plotting(batch_loss_list))) #train losses is a tensor
        train_percent = torch.cat((train_percent,average_for_plotting(batch_percent_list))) #train percent is a tensor
        elapsed = time.time() - start_time
        remaining_iters = max_iters - iter
        avg_time_per_iter = elapsed / (iter + 1)
        estimated_remaining_time = remaining_iters * avg_time_per_iter

        if iter % eval_interval == 0 or iter == max_iters - 1:
                        
            #evaluate the model
            val_loss, val_accuracy = estimate_loss()
        
            val_losses.append(val_loss) #average loss of the val dataset, this is a scalar
            val_percent.append(val_accuracy) #percent of correct predictions in the val set, this is a scalar


            print(f"step {iter}: train loss:  {train_losses[-1].mean().item():.4f}, val loss: {val_losses[-1]:.4f}") #report the average loss of the batch
            print(f"step {iter}: train accuracy:  {(train_percent[-1].mean().item())*100:.2f}%, val accuracy: {val_percent[-1]*100:.2f}%")
            writer.add_scalar('val loss', val_losses[-1], iter)
            writer.add_scalar('val accuracy',val_percent[-1], iter)
            torch.save(model.state_dict(), f'checkpoints/{model_name}{iter}.pth')

        tqdm.write(f"Iter [{iter+1}/{max_iters}] - Elapsed Time: {elapsed:.2f}s  Remaining Time: [{estimated_remaining_time:.2f}]")
        if iter == max_iters -1:
            print("Training completed:") 
            print(f"Final train loss: {train_losses[-1].mean().item():.4f},")
            print(f"Final val loss: {val_losses[-1]:.4f}, ")
            print(f"Final train accuracy: {(train_percent[-1].mean().item())*100:.2f}%, ")
            print(f"Final val accuracy: {val_percent[-1]*100:.2f}%")

            
except KeyboardInterrupt:
    print(f"Keyboard interrupt,\nFinal train loss: {train_losses[-1].mean().item():.4f}, ")
    print(f"Final val loss: {val_losses[-1]:.4f}, ")
    print(f"Final train accuracy: {(train_percent[-1].mean().item())*100:.2f}%, ")
    print(f"Final val accuracy: {val_percent[-1]*100:.2f}%")

finally:
    torch.save(model.state_dict(), f'checkpoints/{model_name}_finished.pth')
    with open(f'statistics/{model_name}_finished_train_losses.npy', 'wb') as f:
        np.save(f, np.array(train_losses))
    with open(f'statistics/{model_name}_finished_val_losses.npy', 'wb') as f:
        np.save(f, np.array(val_losses))
    with open(f'statistics/{model_name}_finished_train_percent.npy', 'wb') as f:
        np.save(f, np.array(train_percent))
    with open(f'statistics/{model_name}_finished_val_percent.npy', 'wb') as f:
        np.save(f, np.array(val_percent))
    print(f"Model saved!")

In [None]:
'''This version of the program uses Nvidia Dali to load data, not torchvision.io.read_video,
   It should be substantially faster, especially with multiple gpus, perhaps a good setup 
   would be one to load the videos, one to run the training loop? Perhaps not as I learned more about it.

    Eventually, this and the other version in this notebook should be merged into one notebook, with a flag to choose which to use.
   
'''
import os
import torch
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
from tqdm import tqdm
from psutil import cpu_count
from mobilenet import MobileNetLarge2D, MobileNetSmall2D, MobileNetSmall3D, MobileNetLarge3D
from movinet import MoViNetA2
from helpers import calculate_accuracy, video_pipeline, average_for_plotting

'''
Our mean is ([0.3939, 0.3817, 0.3314])
Our std is ([0.2104, 0.1986, 0.1829])
'''



'''Strikes are 0, balls 1.'''

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
#hyperparameters
learning_rate = 3e-4 #the paper quotes rmsprop with 0.1 lr, but we have a tiny batch size, and are using AdamW
batch_size = 64 #the paper quotes 128 images/chip, but with video we have to change this
max_iters = 100
eval_interval = 1
weight_decay = 0.0005
momentum = 0.9
eps = np.sqrt(0.002) #From the pytorch blog post, "a reasonable approximation can be taken with the formula PyTorch_eps = sqrt(TF_eps)."
std = torch.tensor([0.2104, 0.1986, 0.1829])[None,None,None,:]
mean = torch.tensor([0.3539, 0.3817, 0.3314])[None,None,None,:]
use_autocast = True
compile = False

#information for the dali pipeline
sequence_length = 130 #longest videos in our dataset 
initial_prefetch_size = 20

#video paths
train_video_paths = '/workspace/picklebotdataset/train'
val_video_paths = '/workspace/picklebotdataset/val'

num_train_videos = len(os.listdir(train_video_paths + '/' + 'balls')) + len(os.listdir(train_video_paths + '/' + 'strikes'))
num_val_videos = len(os.listdir(val_video_paths + '/' + 'balls')) + len(os.listdir(val_video_paths + '/' + 'strikes'))

#define our model, initialize weights
model = MoViNetA2()
model.initialize_weights()
model = model.to(device)

#for multi-gpu setups 
#may want to revisit this and choose which device we use for loading with dali, and which to use for training the net.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

#define our optimizer
#optimizer = optim.RMSprop(params=model.parameters(),lr=learning_rate,weight_decay=weight_decay,momentum=momentum,eps=eps) #starting with AdamW for now. 
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)

#cosine annealing
scheduler = CosineAnnealingLR(optimizer, T_max=100)

#loss
criterion = nn.CrossEntropyLoss() 
if use_autocast:
    scaler = GradScaler()
model_name = model.__class__.__name__ 
writer = SummaryWriter(f'runs/{model_name}') #tensorboard writer
# model.load_state_dict(torch.load(f'{model_name}.pth')) #if applicable, load the model from the last checkpoint


if compile:
    print("compiling the model... (takes a ~minute)")
    unoptimized_model = model
    model = torch.compile(model)  # requires PyTorch 2 and a modern gpu, these lines were lifted from karpathy
    print("compilation complete!")

#estimate_loss using validation set, we should refactor this.
@torch.no_grad()
def estimate_loss():
    #evaluate the model
    model.eval()
    val_losses = []
    val_correct = 0
    val_samples = 0

    #calculate the loss
    for _,val_features in tqdm(enumerate(val_loader)):
        val_labels = (val_features[0]['label']).view(-1).long() #need this as a (batch_size,) tensor
        val_features = val_features[0]['data']/255
        val_features = val_features.permute(0,-1,1,2,3) 
        # val_labels = val_labels.expand(val_features.shape[2]) #this is only for our lstm T -> batch size, a lame hack

        val_outputs = model(val_features)
        
        val_loss = criterion(val_outputs,val_labels)
        
        val_losses.append(val_loss.item())  
        
        val_correct += calculate_accuracy(val_outputs,val_labels) #get number of correct
        val_samples += len(labels) #this is the total number of samples so far

    avg_val_loss = np.mean(val_losses)
    val_accuracy = val_correct / val_samples
    return avg_val_loss, val_accuracy


#initialize lists for plotting
start_time = time.time()
train_losses = torch.tensor([])
train_percent = torch.tensor([])
val_losses = []
val_percent = []
counter = 0

#build our pipelines
train_pipe = video_pipeline(batch_size=batch_size, num_threads=cpu_count(), device_id=0, file_root=train_video_paths,
                            sequence_length=sequence_length,initial_prefetch_size=initial_prefetch_size,mean=mean,std=std)
val_pipe = video_pipeline(batch_size=batch_size, num_threads=cpu_count(), device_id=0, file_root=val_video_paths,
                          sequence_length=sequence_length,initial_prefetch_size=initial_prefetch_size,mean=mean,std=std)

train_pipe.build()
val_pipe.build()


train_loader = DALIClassificationIterator(train_pipe, auto_reset=True,last_batch_policy=LastBatchPolicy.PARTIAL, size=num_train_videos)
val_loader = DALIClassificationIterator(val_pipe, auto_reset=True,last_batch_policy=LastBatchPolicy.PARTIAL, size=num_val_videos)

try:
    for iter in range(max_iters):
        
        model.train()
        train_correct = 0
        train_samples = 0
        batch_loss_list = [] #want to overwrite this each epoch
        batch_percent_list = []

        #forward pass
        for batch_idx, features in tqdm(enumerate(train_loader)):
            
            labels = (features[0]['label']).view(-1).long() #need this as a (batch_size,) tensor in int64
            features = features[0]['data']/255 #i think it makes sense to overwrite features to save precious gpu memory
            features = features.permute(0,-1,1,2,3) #reshape for our 3D convolutions
            # labels = labels.expand(features.shape[2]) #this is only for our lstm T -> batch size, a lame hack
            
            #zero the gradients
            optimizer.zero_grad(set_to_none=True)
            
            if use_autocast:    
                with autocast(dtype=dtype):
                    outputs = model(features)
                    loss = criterion(outputs,labels)
                
                #backprop & update weights

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            else:
                outputs = model(features)
                loss = criterion(outputs,labels)

                #backprop & update weights
                loss.backward()
                optimizer.step()

            #step the scheduler after the epoch
            scheduler.step()
            batch_loss_list.append(loss.item()) #append the loss of the batch to our list to be averaged and plotted later, this is dataset size / batch size long
            batch_correct = calculate_accuracy(outputs,labels) #number of correct predictions in the batch
            train_correct += batch_correct #this is the total number of correct predictions so far
            train_samples += len(labels) #this is the total number of samples so far
            batch_percent_list.append(train_correct/train_samples)
            writer.add_scalar('training loss', batch_loss_list[-1], counter)
            writer.add_scalar('training accuracy', batch_percent_list[-1], counter)
            counter += 1

        train_losses = torch.cat((train_losses,average_for_plotting(batch_loss_list))) #train losses is a tensor
        train_percent = torch.cat((train_percent,average_for_plotting(batch_percent_list))) #train percent is a tensor
        elapsed = time.time() - start_time
        remaining_iters = max_iters - iter
        avg_time_per_iter = elapsed / (iter + 1)
        estimated_remaining_time = remaining_iters * avg_time_per_iter

        if iter % eval_interval == 0 or iter == max_iters - 1:
                        
            #evaluate the model
            val_loss, val_accuracy = estimate_loss()
        
            val_losses.append(val_loss) #average loss of the val dataset, this is a scalar
            val_percent.append(val_accuracy) #percent of correct predictions in the val set, this is a scalar


            print(f"step {iter}: train loss:  {train_losses[-1].mean().item():.4f}, val loss: {val_losses[-1]:.4f}") #report the average loss of the batch
            print(f"step {iter}: train accuracy:  {(train_percent[-1].mean().item())*100:.2f}%, val accuracy: {val_percent[-1]*100:.2f}%")
            writer.add_scalar('val loss', val_losses[-1], iter)
            writer.add_scalar('val accuracy',val_percent[-1], iter)
            torch.save(model.state_dict(), f'checkpoints/{model_name}{iter}.pth')

        tqdm.write(f"Iter [{iter+1}/{max_iters}] - Elapsed Time: {elapsed:.2f}s  Remaining Time: [{estimated_remaining_time:.2f}]")
        if iter == max_iters -1:
            print("Training completed:") 
            print(f"Final train loss: {train_losses[-1].mean().item():.4f},")
            print(f"Final val loss: {val_losses[-1]:.4f}, ")
            print(f"Final train accuracy: {(train_percent[-1].mean().item())*100:.2f}%, ")
            print(f"Final val accuracy: {val_percent[-1]*100:.2f}%")

            
except KeyboardInterrupt:
    print(f"Keyboard interrupt,\nFinal train loss: {train_losses[-1].mean().item():.4f}, ")
    print(f"Final val loss: {val_losses[-1]:.4f}, ")
    print(f"Final train accuracy: {(train_percent[-1].mean().item())*100:.2f}%, ")
    print(f"Final val accuracy: {val_percent[-1]*100:.2f}%")

finally:
    torch.save(model.state_dict(), f'checkpoints/{model_name}_finished.pth')
    with open(f'statistics/{model_name}_finished_train_losses.npy', 'wb') as f:
        np.save(f, np.array(train_losses))
    with open(f'statistics/{model_name}_finished_val_losses.npy', 'wb') as f:
        np.save(f, np.array(val_losses))
    with open(f'statistics/{model_name}_finished_train_percent.npy', 'wb') as f:
        np.save(f, np.array(train_percent))
    with open(f'statistics/{model_name}_finished_val_percent.npy', 'wb') as f:
        np.save(f, np.array(val_percent))
    print(f"Model saved!")

In [1]:

'''trying with binary cross entropy loss'''
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from psutil import cpu_count
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from dataloader import PicklebotDataset, custom_collate
from mobilenet import MobileNetLarge2D, MobileNetSmall2D, MobileNetSmall3D,MobileNetLarge3D
from movinet import MoViNetA2
from helpers import calculate_accuracy_bce, average_for_plotting

device = 'cuda' if torch.cuda.is_available() else 'cpu'

#hyperparameters
torch.manual_seed(1234)
learning_rate = 0.000238167787843871 #this seems nuts, but it's the lr we left off at
batch_size = 4
max_iters = 100
eval_interval = 1
weight_decay = 5e-4
std = (0.2104, 0.1986, 0.1829)
mean = (0.3939, 0.3817, 0.3314)
use_autocast = False 
compile = False

#video paths
train_video_paths = '/workspace/picklebotdataset/train'
val_video_paths = '/workspace/picklebotdataset/val'

train_video_paths = '/home/henry/Documents/PythonProjects/picklebotdataset/train_all_together'
val_video_paths = '/home/henry/Documents/PythonProjects/picklebotdataset/val_all_together'

#annotations paths
train_annotations_file = '/home/henry/Documents/PythonProjects/picklebotdataset/train_labels.csv'
val_annotations_file = '/home/henry/Documents/PythonProjects/picklebotdataset/val_labels.csv'

#establish normalization using transforms
transform = transforms.Normalize(mean,std)

#dataset     
train_dataset = PicklebotDataset(train_annotations_file,train_video_paths,transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True,collate_fn=custom_collate,num_workers=cpu_count())
val_dataset = PicklebotDataset(val_annotations_file,val_video_paths,transform=transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size,shuffle=True,collate_fn=custom_collate,num_workers=cpu_count())

#define model
model = MobileNetSmall3D(num_classes=1)
model = model.to(device)

#for multi-gpu
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

# optimizer
optimizer = optim.AdamW(params=model.parameters(),lr=learning_rate, weight_decay=weight_decay)

#cosine annealing
scheduler = CosineAnnealingLR(optimizer, T_max=70)

#loss
criterion = nn.BCELoss()  # Update to use Binary Cross Entropy Loss
if use_autocast:
    scaler = GradScaler()
    
model_name = model.__class__.__name__+'bce'
writer = SummaryWriter(f'runs/{model_name}') #tensorboard writer
# checkpoint = torch.load(f'checkpoints/{model_name}30.pth')
# model.load_state_dict(checkpoint)

if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)
    print("compilation complete!")

#estimate loss using the val set, and calculate accuracy
@torch.no_grad()
def estimate_loss():
    model.eval()
    val_losses = []
    val_correct = 0
    val_samples = 0

    for val_features, val_labels in tqdm(val_loader):
        val_features = val_features.to(device)
        val_labels = val_labels.float().to(device)  # Convert labels to float and move to device
        val_labels = val_labels.unsqueeze(1)  # Add a dimension to labels to match output shape
        val_outputs = model(val_features)
        val_outputs = torch.sigmoid(val_outputs)  # Apply sigmoid activation function
        val_loss = criterion(val_outputs, val_labels)
        val_losses.append(val_loss.item())

        val_correct += calculate_accuracy_bce(val_outputs, val_labels)
        val_samples += len(val_labels)

    avg_val_loss = np.mean(val_losses)
    val_accuracy = val_correct / val_samples
    return avg_val_loss, val_accuracy

#training loop
start_time = time.time()
train_losses = torch.tensor([])
train_percent = torch.tensor([])
val_losses = []
val_percent = []
counter = 30

try:
    for iter in range(max_iters):
        model.train()
        train_correct = 0
        train_samples = 0
        batch_loss_list = []
        batch_percent_list = []

        for batch_idx, (features, labels) in tqdm(enumerate(train_loader)):
            labels = labels.float().to(device)  # Convert labels to float and move to device
            labels = labels.unsqueeze(1)  # Add a dimension to labels to match output shape
            features = features.to(device)

            # Zero the gradients
            optimizer.zero_grad(set_to_none=True)

            if use_autocast:
                with autocast():
                    outputs = model(features)
                    outputs = torch.sigmoid(outputs)  # Apply sigmoid activation function
                    loss = criterion(outputs, labels)
                
                # Backprop & update weights
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(features)
                outputs = torch.sigmoid(outputs)  # Apply sigmoid activation function
                loss = criterion(outputs.cpu(), labels.cpu())

                # Backprop & update weights
                loss.backward()
                optimizer.step()

            batch_loss_list.append(loss.item())
            batch_correct = calculate_accuracy_bce(outputs, labels)
            train_correct += batch_correct
            train_samples += len(labels)
            batch_percent_list.append(batch_correct / float(len(labels)))
            writer.add_scalar('training loss', batch_loss_list[-1], counter)
            writer.add_scalar('training accuracy', batch_percent_list[-1], counter)
            counter += 1

        scheduler.step()
        train_losses = torch.cat((train_losses, average_for_plotting(batch_loss_list)))
        train_percent = torch.cat((train_percent, average_for_plotting(batch_percent_list)))
        elapsed = time.time() - start_time
        remaining_iters = max_iters - iter
        avg_time_per_iter = elapsed / (iter + 1)
        estimated_remaining_time = remaining_iters * avg_time_per_iter

        if iter % eval_interval == 0 or iter == max_iters - 1:
            val_loss, val_accuracy = estimate_loss()
            val_losses.append(val_loss)
            val_percent.append(val_accuracy)

            print(f"Step {iter}: Train Loss: {train_losses[-1].mean().item():.4f}, Val Loss: {val_losses[-1]:.4f}")
            print(f"Step {iter}: Train Accuracy: {(train_percent[-1].mean().item())*100:.2f}%, Val Accuracy: {val_percent[-1]*100:.2f}%")
            writer.add_scalar('val loss', val_losses[-1], iter+31)
            writer.add_scalar('val accuracy', val_percent[-1], iter+31)
            torch.save(model.state_dict(), f'checkpoints/{model_name}{iter+31}.pth')

        tqdm.write(f"Iter [{iter+1}/{max_iters}] - Elapsed Time: {elapsed:.2f}s - Remaining Time: [{estimated_remaining_time:.2f}]")

        if iter == max_iters - 1:
            print("Training completed:")
            print(f"Final Train Loss: {train_losses[-1].mean().item():.4f}")
            print(f"Final Val Loss: {val_losses[-1]:.4f}")
            print(f"Final Train Accuracy: {(train_percent[-1].mean().item())*100:.2f}%")
            print(f"Final Val Accuracy: {val_percent[-1]*100:.2f}%")
except KeyboardInterrupt:
    print(f"Keyboard interrupt,\nFinal Train Loss: {train_losses[-1].mean().item():.4f}")
    print(f"Final Val Loss: {val_losses[-1]:.4f}")
    print(f"Final Train Accuracy: {(train_percent[-1].mean().item())*100:.2f}%")
    print(f"Final Val Accuracy: {val_percent[-1]*100:.2f}%")
finally:
    torch.save(model.state_dict(), f'checkpoints/{model_name}_finished.pth')
    with open(f'statistics/{model_name}_finished_train_losses.npy', 'wb') as f:
        np.save(f, train_losses.numpy())
    with open(f'statistics/{model_name}_finished_val_losses.npy', 'wb') as f:
        np.save(f, np.array(val_losses))
    with open(f'statistics/{model_name}_finished_train_percent.npy', 'wb') as f:
        np.save(f, train_percent.numpy())
    with open(f'statistics/{model_name}_finished_val_percent.npy', 'wb') as f:
        np.save(f, np.array(val_percent))
    print(f"Model and statistics saved!")

10773it [2:34:21,  1.16it/s]
  avg_losses = torch.tensor(loss_list[:-partial_size]).view(-1,1000).mean(1)
  avg_partial = torch.tensor(loss_list[-partial_size:]).view(-1,partial_size).mean(1)
100%|██████████| 1348/1348 [19:54<00:00,  1.13it/s]

Step 0: Train Loss: 0.6934, Val Loss: 0.6932
Step 0: Train Accuracy: 49.55%, Val Accuracy: 49.94%
Iter [1/100] - Elapsed Time: 9261.69s - Remaining Time: [926168.86]



10773it [2:34:10,  1.16it/s]
100%|██████████| 1348/1348 [19:53<00:00,  1.13it/s]

Step 1: Train Loss: 0.6935, Val Loss: 0.6932
Step 1: Train Accuracy: 48.58%, Val Accuracy: 50.04%
Iter [2/100] - Elapsed Time: 19707.00s - Remaining Time: [975496.30]



10773it [2:34:17,  1.16it/s]
100%|██████████| 1348/1348 [19:49<00:00,  1.13it/s]

Step 2: Train Loss: 0.6933, Val Loss: 0.6932
Step 2: Train Accuracy: 49.06%, Val Accuracy: 49.70%
Iter [3/100] - Elapsed Time: 30158.29s - Remaining Time: [985170.82]



10773it [2:34:15,  1.16it/s]
100%|██████████| 1348/1348 [20:00<00:00,  1.12it/s]

Step 3: Train Loss: 0.6932, Val Loss: 0.6953
Step 3: Train Accuracy: 51.03%, Val Accuracy: 50.04%
Iter [4/100] - Elapsed Time: 40603.71s - Remaining Time: [984640.01]



10773it [2:40:17,  1.12it/s]
100%|██████████| 1348/1348 [20:20<00:00,  1.10it/s]

Step 4: Train Loss: 0.6934, Val Loss: 0.6932
Step 4: Train Accuracy: 50.03%, Val Accuracy: 50.24%
Iter [5/100] - Elapsed Time: 51421.38s - Remaining Time: [987290.51]



10773it [2:33:49,  1.17it/s]
100%|██████████| 1348/1348 [19:56<00:00,  1.13it/s]

Step 5: Train Loss: 0.6930, Val Loss: 0.6932
Step 5: Train Accuracy: 50.06%, Val Accuracy: 50.48%
Iter [6/100] - Elapsed Time: 61872.14s - Remaining Time: [979642.30]



10773it [2:34:39,  1.16it/s]
100%|██████████| 1348/1348 [19:52<00:00,  1.13it/s]

Step 6: Train Loss: 0.6931, Val Loss: 0.6935
Step 6: Train Accuracy: 51.20%, Val Accuracy: 51.71%
Iter [7/100] - Elapsed Time: 72348.92s - Remaining Time: [971542.61]



10773it [2:34:34,  1.16it/s]
100%|██████████| 1348/1348 [19:57<00:00,  1.13it/s]

Step 7: Train Loss: 0.6850, Val Loss: 0.6841
Step 7: Train Accuracy: 55.85%, Val Accuracy: 55.16%
Iter [8/100] - Elapsed Time: 82816.03s - Remaining Time: [962736.34]



10773it [2:30:57,  1.19it/s]
100%|██████████| 1348/1348 [19:40<00:00,  1.14it/s]

Step 8: Train Loss: 0.6731, Val Loss: 0.6788
Step 8: Train Accuracy: 57.15%, Val Accuracy: 56.82%
Iter [9/100] - Elapsed Time: 93071.00s - Remaining Time: [951392.48]



10773it [2:29:52,  1.20it/s]
100%|██████████| 1348/1348 [19:45<00:00,  1.14it/s]

Step 9: Train Loss: 0.6686, Val Loss: 0.6767
Step 9: Train Accuracy: 58.70%, Val Accuracy: 58.49%
Iter [10/100] - Elapsed Time: 103244.83s - Remaining Time: [939527.92]



10773it [2:31:18,  1.19it/s]
100%|██████████| 1348/1348 [19:39<00:00,  1.14it/s]

Step 10: Train Loss: 0.6594, Val Loss: 0.7025
Step 10: Train Accuracy: 58.83%, Val Accuracy: 61.20%
Iter [11/100] - Elapsed Time: 113509.73s - Remaining Time: [928716.01]



10773it [2:30:51,  1.19it/s]
100%|██████████| 1348/1348 [19:39<00:00,  1.14it/s]

Step 11: Train Loss: 0.6444, Val Loss: 0.6959
Step 11: Train Accuracy: 61.71%, Val Accuracy: 60.66%
Iter [12/100] - Elapsed Time: 123741.17s - Remaining Time: [917746.98]



10773it [2:31:55,  1.18it/s]
100%|██████████| 1348/1348 [19:44<00:00,  1.14it/s]

Step 12: Train Loss: 0.6329, Val Loss: 0.6420
Step 12: Train Accuracy: 62.55%, Val Accuracy: 61.83%
Iter [13/100] - Elapsed Time: 134037.01s - Remaining Time: [907327.46]



10773it [2:32:08,  1.18it/s]
100%|██████████| 1348/1348 [19:45<00:00,  1.14it/s]

Step 13: Train Loss: 0.6353, Val Loss: 0.6321
Step 13: Train Accuracy: 63.78%, Val Accuracy: 64.35%
Iter [14/100] - Elapsed Time: 144350.44s - Remaining Time: [897034.90]



10773it [2:32:13,  1.18it/s]
100%|██████████| 1348/1348 [19:51<00:00,  1.13it/s]

Step 14: Train Loss: 0.6197, Val Loss: 0.6329
Step 14: Train Accuracy: 64.62%, Val Accuracy: 62.04%
Iter [15/100] - Elapsed Time: 154669.16s - Remaining Time: [886769.84]



10773it [2:32:31,  1.18it/s]
100%|██████████| 1348/1348 [19:54<00:00,  1.13it/s]

Step 15: Train Loss: 0.6151, Val Loss: 0.6056
Step 15: Train Accuracy: 65.49%, Val Accuracy: 66.47%
Iter [16/100] - Elapsed Time: 165012.53s - Remaining Time: [876629.06]



10773it [2:31:14,  1.19it/s]
100%|██████████| 1348/1348 [19:40<00:00,  1.14it/s]

Step 16: Train Loss: 0.6001, Val Loss: 0.6132
Step 16: Train Accuracy: 67.66%, Val Accuracy: 66.08%
Iter [17/100] - Elapsed Time: 175281.77s - Remaining Time: [866098.18]



10773it [2:31:26,  1.19it/s]
100%|██████████| 1348/1348 [19:48<00:00,  1.13it/s]

Step 17: Train Loss: 0.5874, Val Loss: 0.6173
Step 17: Train Accuracy: 68.47%, Val Accuracy: 67.16%
Iter [18/100] - Elapsed Time: 185549.33s - Remaining Time: [855588.59]



10773it [2:31:11,  1.19it/s]
100%|██████████| 1348/1348 [19:35<00:00,  1.15it/s]

Step 18: Train Loss: 0.5866, Val Loss: 0.5927
Step 18: Train Accuracy: 68.86%, Val Accuracy: 68.53%
Iter [19/100] - Elapsed Time: 195809.56s - Remaining Time: [845072.86]



10773it [2:31:11,  1.19it/s]
100%|██████████| 1348/1348 [19:31<00:00,  1.15it/s]

Step 19: Train Loss: 0.5788, Val Loss: 0.5760
Step 19: Train Accuracy: 69.47%, Val Accuracy: 69.10%
Iter [20/100] - Elapsed Time: 206056.87s - Remaining Time: [834530.34]



10773it [2:30:46,  1.19it/s]
100%|██████████| 1348/1348 [19:35<00:00,  1.15it/s]

Step 20: Train Loss: 0.5735, Val Loss: 0.5730
Step 20: Train Accuracy: 69.21%, Val Accuracy: 69.10%
Iter [21/100] - Elapsed Time: 216275.65s - Remaining Time: [823907.24]



10773it [2:30:36,  1.19it/s]
100%|██████████| 1348/1348 [19:37<00:00,  1.14it/s]

Step 21: Train Loss: 0.5622, Val Loss: 0.5735
Step 21: Train Accuracy: 70.70%, Val Accuracy: 69.57%
Iter [22/100] - Elapsed Time: 226488.88s - Remaining Time: [813300.98]



10773it [2:30:42,  1.19it/s]
100%|██████████| 1348/1348 [19:39<00:00,  1.14it/s]

Step 22: Train Loss: 0.5641, Val Loss: 0.5557
Step 22: Train Accuracy: 70.02%, Val Accuracy: 70.46%
Iter [23/100] - Elapsed Time: 236709.16s - Remaining Time: [802752.80]



10773it [2:30:48,  1.19it/s]
100%|██████████| 1348/1348 [19:39<00:00,  1.14it/s]

Step 23: Train Loss: 0.5527, Val Loss: 0.5762
Step 23: Train Accuracy: 71.90%, Val Accuracy: 70.01%
Iter [24/100] - Elapsed Time: 246937.56s - Remaining Time: [792257.99]



10773it [2:30:41,  1.19it/s]
100%|██████████| 1348/1348 [19:40<00:00,  1.14it/s]

Step 24: Train Loss: 0.5515, Val Loss: 0.5905
Step 24: Train Accuracy: 71.70%, Val Accuracy: 71.96%
Iter [25/100] - Elapsed Time: 257158.34s - Remaining Time: [781761.34]



10773it [2:30:42,  1.19it/s]
100%|██████████| 1348/1348 [19:35<00:00,  1.15it/s]

Step 25: Train Loss: 0.5397, Val Loss: 0.6537
Step 25: Train Accuracy: 72.09%, Val Accuracy: 72.53%
Iter [26/100] - Elapsed Time: 267382.14s - Remaining Time: [771294.64]



10773it [2:30:33,  1.19it/s]
100%|██████████| 1348/1348 [19:36<00:00,  1.15it/s]

Step 26: Train Loss: 0.5488, Val Loss: 0.5722
Step 26: Train Accuracy: 71.86%, Val Accuracy: 72.13%
Iter [27/100] - Elapsed Time: 277591.36s - Remaining Time: [760805.94]



10773it [2:31:15,  1.19it/s]
100%|██████████| 1348/1348 [19:36<00:00,  1.15it/s]

Step 27: Train Loss: 0.5293, Val Loss: 0.6125
Step 27: Train Accuracy: 73.74%, Val Accuracy: 72.74%
Iter [28/100] - Elapsed Time: 287844.06s - Remaining Time: [750450.59]



10773it [2:30:53,  1.19it/s]
100%|██████████| 1348/1348 [19:29<00:00,  1.15it/s]

Step 28: Train Loss: 0.5333, Val Loss: 0.5532
Step 28: Train Accuracy: 73.19%, Val Accuracy: 72.76%
Iter [29/100] - Elapsed Time: 298073.66s - Remaining Time: [740044.94]



10773it [2:30:51,  1.19it/s]
100%|██████████| 1348/1348 [19:40<00:00,  1.14it/s]

Step 29: Train Loss: 0.5126, Val Loss: 0.5918
Step 29: Train Accuracy: 73.90%, Val Accuracy: 73.15%
Iter [30/100] - Elapsed Time: 308295.30s - Remaining Time: [729632.21]



10773it [2:30:23,  1.19it/s]
100%|██████████| 1348/1348 [19:35<00:00,  1.15it/s]

Step 30: Train Loss: 0.5194, Val Loss: 0.5513
Step 30: Train Accuracy: 74.13%, Val Accuracy: 73.00%
Iter [31/100] - Elapsed Time: 318499.44s - Remaining Time: [719192.28]



10773it [2:31:26,  1.19it/s]
100%|██████████| 1348/1348 [19:38<00:00,  1.14it/s]

Step 31: Train Loss: 0.5120, Val Loss: 0.5623
Step 31: Train Accuracy: 75.03%, Val Accuracy: 73.89%
Iter [32/100] - Elapsed Time: 328762.35s - Remaining Time: [708893.81]



10773it [2:30:21,  1.19it/s]
100%|██████████| 1348/1348 [19:24<00:00,  1.16it/s]

Step 32: Train Loss: 0.5110, Val Loss: 0.6003
Step 32: Train Accuracy: 74.03%, Val Accuracy: 72.96%
Iter [33/100] - Elapsed Time: 338962.42s - Remaining Time: [698468.01]



10773it [2:30:42,  1.19it/s]
100%|██████████| 1348/1348 [19:27<00:00,  1.15it/s]

Step 33: Train Loss: 0.5038, Val Loss: 0.5695
Step 33: Train Accuracy: 75.16%, Val Accuracy: 73.94%
Iter [34/100] - Elapsed Time: 349169.76s - Remaining Time: [688069.81]



10773it [2:30:36,  1.19it/s]
100%|██████████| 1348/1348 [19:38<00:00,  1.14it/s]

Step 34: Train Loss: 0.4882, Val Loss: 0.5514
Step 34: Train Accuracy: 76.75%, Val Accuracy: 74.22%
Iter [35/100] - Elapsed Time: 359373.75s - Remaining Time: [677676.21]



10773it [2:31:12,  1.19it/s]
100%|██████████| 1348/1348 [19:33<00:00,  1.15it/s]

Step 35: Train Loss: 0.4891, Val Loss: 0.6409
Step 35: Train Accuracy: 76.94%, Val Accuracy: 73.57%
Iter [36/100] - Elapsed Time: 369625.19s - Remaining Time: [667378.82]



10773it [2:31:20,  1.19it/s]
100%|██████████| 1348/1348 [19:37<00:00,  1.15it/s]

Step 36: Train Loss: 0.4933, Val Loss: 0.6354
Step 36: Train Accuracy: 77.01%, Val Accuracy: 75.13%
Iter [37/100] - Elapsed Time: 379880.12s - Remaining Time: [657089.93]



10773it [2:30:54,  1.19it/s]
100%|██████████| 1348/1348 [19:38<00:00,  1.14it/s]

Step 37: Train Loss: 0.5079, Val Loss: 0.5847
Step 37: Train Accuracy: 74.45%, Val Accuracy: 75.37%
Iter [38/100] - Elapsed Time: 390111.62s - Remaining Time: [646764.00]



10773it [2:31:14,  1.19it/s]
100%|██████████| 1348/1348 [19:37<00:00,  1.14it/s]

Step 38: Train Loss: 0.4931, Val Loss: 0.6426
Step 38: Train Accuracy: 75.71%, Val Accuracy: 75.61%
Iter [39/100] - Elapsed Time: 400365.17s - Remaining Time: [636477.96]



10773it [2:30:26,  1.19it/s]
100%|██████████| 1348/1348 [19:32<00:00,  1.15it/s]

Step 39: Train Loss: 0.5002, Val Loss: 0.5374
Step 39: Train Accuracy: 75.13%, Val Accuracy: 76.08%
Iter [40/100] - Elapsed Time: 410569.73s - Remaining Time: [626118.83]



10773it [2:31:53,  1.18it/s]
100%|██████████| 1348/1348 [19:37<00:00,  1.14it/s]

Step 40: Train Loss: 0.4646, Val Loss: 0.5936
Step 40: Train Accuracy: 77.62%, Val Accuracy: 75.80%
Iter [41/100] - Elapsed Time: 420855.43s - Remaining Time: [615885.99]



10773it [2:30:35,  1.19it/s]
100%|██████████| 1348/1348 [19:35<00:00,  1.15it/s]

Step 41: Train Loss: 0.4806, Val Loss: 0.7904
Step 41: Train Accuracy: 76.78%, Val Accuracy: 74.33%
Iter [42/100] - Elapsed Time: 431069.05s - Remaining Time: [605549.38]



10773it [2:30:21,  1.19it/s]
100%|██████████| 1348/1348 [19:22<00:00,  1.16it/s]

Step 42: Train Loss: 0.4790, Val Loss: 0.6408
Step 42: Train Accuracy: 76.07%, Val Accuracy: 76.76%
Iter [43/100] - Elapsed Time: 441265.83s - Remaining Time: [595195.77]



10773it [2:31:20,  1.19it/s]
100%|██████████| 1348/1348 [19:44<00:00,  1.14it/s]

Step 43: Train Loss: 0.4630, Val Loss: 0.7722
Step 43: Train Accuracy: 77.20%, Val Accuracy: 76.06%
Iter [44/100] - Elapsed Time: 451508.57s - Remaining Time: [584908.83]



10773it [2:31:24,  1.19it/s]
100%|██████████| 1348/1348 [19:32<00:00,  1.15it/s]

Step 44: Train Loss: 0.4811, Val Loss: 0.5970
Step 44: Train Accuracy: 76.71%, Val Accuracy: 76.84%
Iter [45/100] - Elapsed Time: 461777.87s - Remaining Time: [574656.91]



10773it [2:31:26,  1.19it/s]
100%|██████████| 1348/1348 [19:40<00:00,  1.14it/s]

Step 45: Train Loss: 0.4517, Val Loss: 0.8226
Step 45: Train Accuracy: 78.01%, Val Accuracy: 76.26%
Iter [46/100] - Elapsed Time: 472038.01s - Remaining Time: [564393.27]



10773it [2:31:00,  1.19it/s]
100%|██████████| 1348/1348 [19:40<00:00,  1.14it/s]

Step 46: Train Loss: 0.4636, Val Loss: 0.6604
Step 46: Train Accuracy: 77.49%, Val Accuracy: 76.76%
Iter [47/100] - Elapsed Time: 482279.53s - Remaining Time: [554108.40]



10773it [2:30:41,  1.19it/s]
100%|██████████| 1348/1348 [19:44<00:00,  1.14it/s]

Step 47: Train Loss: 0.4534, Val Loss: 0.7824
Step 47: Train Accuracy: 77.46%, Val Accuracy: 76.22%
Iter [48/100] - Elapsed Time: 492501.89s - Remaining Time: [543804.17]



10773it [2:30:59,  1.19it/s]
100%|██████████| 1348/1348 [19:30<00:00,  1.15it/s]

Step 48: Train Loss: 0.4556, Val Loss: 1.5593
Step 48: Train Accuracy: 77.52%, Val Accuracy: 76.61%
Iter [49/100] - Elapsed Time: 502746.55s - Remaining Time: [533526.95]



10773it [2:30:55,  1.19it/s]
100%|██████████| 1348/1348 [19:35<00:00,  1.15it/s]

Step 49: Train Loss: 0.4453, Val Loss: 1.2938
Step 49: Train Accuracy: 78.69%, Val Accuracy: 76.30%
Iter [50/100] - Elapsed Time: 512972.97s - Remaining Time: [523232.43]



10773it [2:31:32,  1.18it/s]
100%|██████████| 1348/1348 [19:36<00:00,  1.15it/s]

Step 50: Train Loss: 0.4511, Val Loss: 1.3825
Step 50: Train Accuracy: 78.30%, Val Accuracy: 76.19%
Iter [51/100] - Elapsed Time: 523241.73s - Remaining Time: [512982.09]



10773it [2:30:58,  1.19it/s]
100%|██████████| 1348/1348 [19:25<00:00,  1.16it/s]

Step 51: Train Loss: 0.4359, Val Loss: 1.0827
Step 51: Train Accuracy: 79.37%, Val Accuracy: 76.67%
Iter [52/100] - Elapsed Time: 533477.03s - Remaining Time: [502699.51]



10773it [2:30:16,  1.19it/s]
100%|██████████| 1348/1348 [19:35<00:00,  1.15it/s]

Step 52: Train Loss: 0.4459, Val Loss: 1.3612
Step 52: Train Accuracy: 78.53%, Val Accuracy: 76.19%
Iter [53/100] - Elapsed Time: 543658.85s - Remaining Time: [492370.28]



10773it [2:30:44,  1.19it/s]
100%|██████████| 1348/1348 [19:32<00:00,  1.15it/s]

Step 53: Train Loss: 0.4342, Val Loss: 1.0435
Step 53: Train Accuracy: 79.69%, Val Accuracy: 77.23%
Iter [54/100] - Elapsed Time: 553879.06s - Remaining Time: [482079.92]



10773it [2:30:24,  1.19it/s]
100%|██████████| 1348/1348 [19:40<00:00,  1.14it/s]

Step 54: Train Loss: 0.4230, Val Loss: 1.4973
Step 54: Train Accuracy: 79.56%, Val Accuracy: 76.91%
Iter [55/100] - Elapsed Time: 564076.40s - Remaining Time: [471772.98]



10773it [2:31:55,  1.18it/s]
100%|██████████| 1348/1348 [19:43<00:00,  1.14it/s]

Step 55: Train Loss: 0.4244, Val Loss: 1.1493
Step 55: Train Accuracy: 80.24%, Val Accuracy: 75.78%
Iter [56/100] - Elapsed Time: 574372.90s - Remaining Time: [461549.66]



4282it [1:00:31,  1.18it/s]


Keyboard interrupt,
Final Train Loss: 0.4244
Final Val Loss: 1.1493
Final Train Accuracy: 80.24%
Final Val Accuracy: 75.78%
Model and statistics saved!


In [None]:
'''For testing our network'''
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from psutil import cpu_count
from torchvision import transforms
from torch.utils.data import DataLoader
from dataloader import PicklebotDataset, custom_collate
from mobilenet import MobileNetLarge2D, MobileNetSmall2D, MobileNetSmall3D,MobileNetLarge3D
from movinet import MoViNetA2

torch.manual_seed(1234)

def calculate_accuracy_bce(outputs, labels, threshold=0.5):
    # Apply threshold to obtain predicted classes and move to CPU
    preds = (outputs >= threshold).float().cpu()

    # Move labels to CPU
    labels = labels.cpu()
    
    balls_correct = ((preds == 0) & (labels == 0)).sum().item() 
    strikes_correct = ((preds == 1) & (labels == 1)).sum().item() 

    return balls_correct, strikes_correct

@torch.no_grad()
def estimate_loss():
    #evaluate the model
    model.eval()
    test_losses = [] 
    balls_correct = 0
    strikes_correct = 0
    test_samples = 0

    #calculate the loss
    for test_features,test_labels in tqdm(test_loader):
        test_features = test_features.to(device)
        test_labels = test_labels.float().to(device) #waiting to move to device until after forward pass, idk if this matters
        # val_labels = val_labels.expand(val_features.shape[2]) #this is only for our lstm T -> batch size, a lame hack    
        test_outputs = model(test_features)
        test_labels = test_labels.unsqueeze(1)
        test_loss = criterion(test_outputs,test_labels)
        test_losses.append(test_loss.item())
        
        correct_tuple = calculate_accuracy_bce(test_outputs,test_labels)
        balls_correct += correct_tuple[0]
        strikes_correct += correct_tuple[1]
        test_samples += len(test_labels)

    avg_test_loss = np.mean(test_losses)
    balls_accuracy = balls_correct / 2694 #too lazy to not use the magic number from the spreadsheet rn
    strikes_accuracy = strikes_correct / 2693 #too lazy to not use the magic number from the spreadsheet rn
    return avg_test_loss, balls_accuracy, strikes_accuracy 

device = 'cuda' if torch.cuda.is_available() else 'cpu'
std = (0.2104, 0.1986, 0.1829)
mean = (0.3939, 0.3817, 0.3314)
batch_size = 4 

#annotations paths
test_annotations_file = '/home/henry/Documents/PythonProjects/picklebotdataset/test_labels.csv'

#video paths
test_video_paths = '/home/henry/Documents/PythonProjects/picklebotdataset/test_all_together'

#establish our normalization using transforms, 
#note that we are doing this in our dataloader as opposed to in the training loop like with dali
transform = transforms.Normalize(mean,std)

#dataset     
test_dataset = PicklebotDataset(test_annotations_file,test_video_paths,transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size,shuffle=True,collate_fn=custom_collate,num_workers=cpu_count())

model = MobileNetSmall3D(num_classes=1)
criterion = nn.BCEWithLogitsLoss()
for i in range(25,50):
    model.load_state_dict(torch.load(f'checkpoints/MobileNetSmall3Dbce{i}.pth'))
    model.to(device)
    avg_test_loss,balls_accuracy,strikes_accuracy = estimate_loss()

    print(f'mobilenet small {i} test loss: {avg_test_loss:.4f}, mobilenet small ball test accuracy: {balls_accuracy * 100:.2f}% mobilenet small strike test accuracy: {strikes_accuracy * 100:.2f}%')

In [42]:
'''Calculate the number of parameters in each model, for comparison purposes. 
   Note that movinet is about 2.8x larger than mobilenet small, and mobilenet large is about 2.5x larger than mobilenet small.'''

from movinet import MoViNetA2
from mobilenet import MobileNetLarge3D
movinet = MoViNetA2()
mobilenet_large = MobileNetLarge3D()
mobilenet_small = MobileNetSmall3D()

movinet_params = sum(p.numel() for p in movinet.parameters())
mobilenet_large_params = sum(p.numel() for p in mobilenet_large.parameters())
mobilenet_small_params = sum(p.numel() for p in mobilenet_small.parameters())
print(f"number of parameters in movinet: {movinet_params}")
print(f"number of parameters in mobilenet large: {mobilenet_large_params}")
print(f"number of parameters in mobilenet small: {mobilenet_small_params}")

number of parameters in movinet: 4660762
number of parameters in mobilenet large: 4191584
number of parameters in mobilenet small: 1672816
