## Dependencies and configuration

In [1]:
import os
import torch
from torch.utils.data import DataLoader
import torchvision.transforms.functional as TF
from src_3d.data_loader import EchoNetDataset
from src_3d.utils import visualize_random_video_from_loader
from src.model11 import MobileNetV3UNet3D
from src_3d.train import train_UNet3D_weak_supervision

# Global parameters
SEED = 42
# torch.manual_seed seeds torch's RNG (DataLoader shuffling, weight init). Python's `random`
# and numpy are NOT seeded here; if src_3d relies on them, runs may not be fully reproducible.
torch.manual_seed(SEED)
T = 16  # video length, passed as `length=` to EchoNetDataset
batch_size = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 1e-3
num_epochs = 500
patience = 50  # forwarded to train_UNet3D_weak_supervision (presumably early-stopping patience in epochs - TODO confirm)

# Dataset root. Set the ECHONET_DATA_PATH environment variable to point at your local copy
# instead of editing the notebook; falls back to the original author's path.
data_path = os.environ.get(
    "ECHONET_DATA_PATH",
    r"C:\Projects\python\echoframe\data\EchoNet-Dynamic\EchoNet-Dynamic",
)

# Checkpoint and CSV-log output paths for the three experiments below.
# Each log file is named after its checkpoint so runs cannot be confused or overwrite each other.
model_path_1 = r'./models/pretrained_mobilenet_3d.pt'
log_path_1 = r'./logs/pretrained_mobilenet_3d.csv'
model_path_2 = r'./models/pretrained_masked_mobilenet_3d.pt'
log_path_2 = r'./logs/pretrained_masked_mobilenet_3d.csv'  # was './logs/masked_autoencoder.csv'
model_path_3 = r'./models/scratch_mobilenet_3d.pt'
log_path_3 = r'./logs/scratch_mobilenet_3d.csv'


## Load Data

In [2]:
# Build the train / val / test splits of EchoNet-Dynamic as clips of length T.
# Only the training loader shuffles; all three share the configured batch_size.
train_data = EchoNetDataset(root=data_path,
                            split='train',
                            length=T)
val_data = EchoNetDataset(root=data_path,
                          split='val',
                          length=T)
test_data = EchoNetDataset(root=data_path,
                           split='test',
                           length=T)

# Fail fast with a clear message if data_path is wrong and a split comes back empty.
assert len(train_data) > 0, "EchoNetDataset 'train' split is empty - check data_path"
assert len(val_data) > 0, "EchoNetDataset 'val' split is empty - check data_path"
assert len(test_data) > 0, "EchoNetDataset 'test' split is empty - check data_path"

train_loader = DataLoader(dataset=train_data,
                          batch_size=batch_size,
                          shuffle=True)
val_loader = DataLoader(dataset=val_data,
                        batch_size=batch_size)
# Previously hard-coded to batch_size=4; use the config value so all splits stay consistent.
test_loader = DataLoader(dataset=test_data,
                         batch_size=batch_size)

[EchoNetDataset] Final usable videos: 7460
[EchoNetDataset] Final usable videos: 1288
[EchoNetDataset] Final usable videos: 1276


## Train

### Model 1 - Pretrained MobileNet3D

In [None]:
# Model 1: the encoder is initialised from the pretrained MobileNet3D encoder checkpoint,
# then the whole UNet is trained with weak supervision.
model_1 = MobileNetV3UNet3D()
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a CPU-only machine
# (device may be 'cpu'); load_state_dict then copies the values into the model's own parameters.
model_1.encoder.load_state_dict(
    torch.load("./models/pretrained_encoder.pt", map_location="cpu")
)

train_UNet3D_weak_supervision(model=model_1,
                              train_loader=train_loader,
                              valid_loader=val_loader,
                              device=device,
                              num_epochs=num_epochs,
                              lr=lr,
                              log_path=log_path_1,
                              model_path=model_path_1,
                              patience=patience)

# Drop the model and release cached GPU memory before the next experiment.
del model_1
torch.cuda.empty_cache()


Epoch 1 [Training]:  75%|███████▌  | 1405/1865 [08:38<02:45,  2.78it/s]

### Model 2 - Masked-Pretrained MobileNet3D

In [None]:
# Model 2: the encoder is initialised from the masked-pretraining encoder checkpoint,
# then the whole UNet is trained with weak supervision.
model_2 = MobileNetV3UNet3D()
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a CPU-only machine
# (device may be 'cpu'); load_state_dict then copies the values into the model's own parameters.
model_2.encoder.load_state_dict(
    torch.load("./models/pretrained_masked_encoder.pt", map_location="cpu")
)

train_UNet3D_weak_supervision(model=model_2,
                              train_loader=train_loader,
                              valid_loader=val_loader,
                              device=device,
                              num_epochs=num_epochs,
                              lr=lr,
                              log_path=log_path_2,
                              model_path=model_path_2,
                              patience=patience)

# Drop the model and release cached GPU memory before the next experiment.
del model_2
torch.cuda.empty_cache()

### Model 3 - MobileNet3D from Scratch

In [None]:
# Model 3 (baseline): identical architecture, but no pretrained weights are loaded,
# so training starts from the model's default initialisation.
model_3 = MobileNetV3UNet3D()

train_UNet3D_weak_supervision(
    model=model_3,
    train_loader=train_loader,
    valid_loader=val_loader,
    device=device,
    num_epochs=num_epochs,
    lr=lr,
    log_path=log_path_3,
    model_path=model_path_3,
    patience=patience,
)

# Free the model and any cached GPU memory once training has finished.
del model_3
torch.cuda.empty_cache()