## Dependencies and Configuration

In [1]:
import os
import torch
from torch.utils.data import DataLoader
import torchvision.transforms.functional as TF
from src_3d.data_loader import EchoNetDataset
from src_3d.utils import visualize_random_video_from_loader, test_3d_unet, visualize_clip_with_overlay, save_overlay_gif_from_loader
# from src.model11 import MobileNetV3UNet3D
from src_3d.train1 import train_UNet3D_weak_supervision
from src_3d.model1 import MobileNetV3UNet3D

# Global parameters
torch.manual_seed(42)  # seeds torch's RNGs: weight init and DataLoader shuffling become repeatable
T = 16  # clip length passed as `length` to EchoNetDataset (presumably frames per clip)
batch_size = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 1e-3
num_epochs = 500  # upper bound on epochs; the patience counter can end training earlier
patience = 10     # presumably epochs without validation improvement before early stopping

# Root of the EchoNet-Dynamic dataset. Set the ECHONET_DATA_PATH environment
# variable to point elsewhere instead of editing this machine-specific default.
data_path = os.environ.get(
    "ECHONET_DATA_PATH",
    r"C:\Projects\python\echoframe\data\EchoNet-Dynamic\EchoNet-Dynamic",
)

# Checkpoint and training-log locations for the three experiments below.
# NOTE(review): log_path_2 does not follow the <model-name>.csv pattern of the
# other two logs; kept unchanged so existing log files are still found.
model_path_1 = r'./models/pretrained_mobilenet_3d.pt'
log_path_1 = r'./logs/pretrained_mobilenet_3d.csv'
model_path_2 = r'./models/pretrained_masked_mobilenet_3d.pt'
log_path_2 = r'./logs/masked_autoencoder.csv'
model_path_3 = r'./models/scratch_mobilenet_3d.pt'
log_path_3 = r'./logs/scratch_mobilenet_3d.csv'

# Create the output folders up front so checkpoint saving, CSV logging and GIF
# export do not fail on a fresh clone.
for output_dir in ('./models', './logs', './assets'):
    os.makedirs(output_dir, exist_ok=True)


## Load Data

In [2]:
# One dataset + loader per EchoNet-Dynamic split; `length=T` sets the clip
# length configured in the cell above.
train_data = EchoNetDataset(root=data_path,
                            split='train',
                            length=T)
# shuffle=True reshuffles the training clips at the start of every epoch.
train_loader = DataLoader(dataset=train_data,
                          batch_size=batch_size,
                          shuffle=True)

val_data = EchoNetDataset(root=data_path,
                          split='val',
                          length=T)
val_loader = DataLoader(dataset=val_data,
                        batch_size=batch_size)

# The test loader keeps the default sequential order, so the batch_idx/sample_idx
# used for GIF export below always refer to the same clip. It uses the shared
# batch_size constant rather than a separate hard-coded value.
test_data = EchoNetDataset(root=data_path,
                           split='test',
                           length=T)
test_loader = DataLoader(dataset=test_data,
                         batch_size=batch_size)

[EchoNetDataset] Final usable videos: 7460
[EchoNetDataset] Final usable videos: 1288
[EchoNetDataset] Final usable videos: 1276


## Train

### Model 1 - Pretrained MobileNet3D

In [None]:
# Model 1: initialise the U-Net encoder from the pretrained encoder checkpoint,
# then train it with train_UNet3D_weak_supervision.
model_1 = MobileNetV3UNet3D()

# Training is expensive (up to num_epochs epochs), so it is skipped when a
# checkpoint from an earlier run already exists at model_path_1.
# Set this flag to True to force a fresh training run.
retrain_model_1 = False

if retrain_model_1 or not os.path.exists(model_path_1):
    # The encoder initialisation only matters for training: the test cell below
    # is handed model_path_1 and (presumably) loads the trained weights from it.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model_1.encoder.load_state_dict(
        torch.load("./models/pretrained_encoder.pt", map_location=device))
    train_UNet3D_weak_supervision(model=model_1,
                                  train_loader=train_loader,
                                  valid_loader=val_loader,
                                  device=device,
                                  num_epochs=num_epochs,
                                  lr=lr,
                                  log_path=log_path_1,
                                  model_path=model_path_1,
                                  patience=patience)
else:
    print(f"Checkpoint found at {model_path_1}; skipping training "
          f"(set retrain_model_1 = True to retrain).")

                                                                       

Epoch 1/500 | Train Loss: 0.2412 | Valid Loss: 0.1783 | Valid Dice: 0.8874 | Patience: 0


                                                                       

Epoch 2/500 | Train Loss: 0.1779 | Valid Loss: 0.1652 | Valid Dice: 0.8933 | Patience: 0


                                                                       

Epoch 3/500 | Train Loss: 0.1695 | Valid Loss: 0.1743 | Valid Dice: 0.8904 | Patience: 1


                                                                       

Epoch 4/500 | Train Loss: 0.1619 | Valid Loss: 0.1488 | Valid Dice: 0.9049 | Patience: 0


                                                                       

Epoch 5/500 | Train Loss: 0.1555 | Valid Loss: 0.1492 | Valid Dice: 0.9040 | Patience: 1


                                                                       

Epoch 6/500 | Train Loss: 0.1529 | Valid Loss: 0.1535 | Valid Dice: 0.9019 | Patience: 2


                                                                       

Epoch 7/500 | Train Loss: 0.1530 | Valid Loss: 0.1419 | Valid Dice: 0.9091 | Patience: 0


                                                                       

Epoch 8/500 | Train Loss: 0.1490 | Valid Loss: 0.1427 | Valid Dice: 0.9083 | Patience: 1


                                                                       

Epoch 9/500 | Train Loss: 0.1476 | Valid Loss: 0.1440 | Valid Dice: 0.9078 | Patience: 2


                                                                        

Epoch 10/500 | Train Loss: 0.1439 | Valid Loss: 0.1401 | Valid Dice: 0.9097 | Patience: 0


                                                                        

Epoch 11/500 | Train Loss: 0.1434 | Valid Loss: 0.1407 | Valid Dice: 0.9094 | Patience: 1


                                                                        

Epoch 12/500 | Train Loss: 0.1404 | Valid Loss: 0.1366 | Valid Dice: 0.9118 | Patience: 0


                                                                        

Epoch 13/500 | Train Loss: 0.1395 | Valid Loss: 0.1425 | Valid Dice: 0.9081 | Patience: 1


                                                                        

Epoch 14/500 | Train Loss: 0.1386 | Valid Loss: 0.1380 | Valid Dice: 0.9113 | Patience: 2


                                                                        

Epoch 15/500 | Train Loss: 0.1365 | Valid Loss: 0.1351 | Valid Dice: 0.9128 | Patience: 0


                                                                        

Epoch 16/500 | Train Loss: 0.1372 | Valid Loss: 0.1341 | Valid Dice: 0.9134 | Patience: 0


                                                                        

Epoch 17/500 | Train Loss: 0.1345 | Valid Loss: 0.1490 | Valid Dice: 0.9054 | Patience: 1


Epoch 18 [Training]:  93%|█████████▎| 1737/1865 [10:30<00:48,  2.62it/s]

#### Test

In [None]:
# Evaluate Model 1 on the test split and render qualitative overlays. Each helper
# is given model_path_1, presumably so it can load the saved checkpoint itself.
for run_on_test_set in (test_3d_unet, visualize_clip_with_overlay):
    run_on_test_set(model=model_1,
                    test_loader=test_loader,
                    model_path=model_path_1,
                    device=device)

# Export the overlay for the first clip of the first test batch as a GIF.
# NOTE: 'echocardigram' (sic) is kept so the output filename does not change.
save_overlay_gif_from_loader(model=model_1,
                             test_loader=test_loader,
                             model_path=model_path_1,
                             save_path='./assets/annotated_echocardigram_1.gif',
                             batch_idx=0,
                             sample_idx=0)

# Drop the notebook's reference to the model and return cached CUDA memory
# before the next experiment allocates its own model.
del model_1
torch.cuda.empty_cache()

✅ Saved overlay GIF to: ./assets/annotated_echocardigram_1.gif


### Model 2 - MobileNet3D with Masked-Pretrained Encoder

In [None]:
# Model 2: initialise the U-Net encoder from the masked-pretraining encoder
# checkpoint, then train it with train_UNet3D_weak_supervision.
model_2 = MobileNetV3UNet3D()

# Training is expensive (up to num_epochs epochs), so it is skipped when a
# checkpoint from an earlier run already exists at model_path_2.
# Set this flag to True to force a fresh training run.
retrain_model_2 = False

if retrain_model_2 or not os.path.exists(model_path_2):
    # The encoder initialisation only matters for training: the test cell below
    # is handed model_path_2 and (presumably) loads the trained weights from it.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model_2.encoder.load_state_dict(
        torch.load("./models/pretrained_masked_encoder.pt", map_location=device))
    train_UNet3D_weak_supervision(model=model_2,
                                  train_loader=train_loader,
                                  valid_loader=val_loader,
                                  device=device,
                                  num_epochs=num_epochs,
                                  lr=lr,
                                  log_path=log_path_2,
                                  model_path=model_path_2,
                                  patience=patience)
else:
    print(f"Checkpoint found at {model_path_2}; skipping training "
          f"(set retrain_model_2 = True to retrain).")


In [None]:
# Evaluate Model 2 on the test split and render qualitative overlays. Each helper
# is given model_path_2, presumably so it can load the saved checkpoint itself.
for run_on_test_set in (test_3d_unet, visualize_clip_with_overlay):
    run_on_test_set(model=model_2,
                    test_loader=test_loader,
                    model_path=model_path_2,
                    device=device)

# Export the overlay for the first clip of the first test batch as a GIF.
# NOTE: 'echocardigram' (sic) is kept so the output filename does not change.
save_overlay_gif_from_loader(model=model_2,
                             test_loader=test_loader,
                             model_path=model_path_2,
                             save_path='./assets/annotated_echocardigram_2.gif',
                             batch_idx=0,
                             sample_idx=0)

# Drop the notebook's reference to the model and return cached CUDA memory
# before the next experiment allocates its own model.
del model_2
torch.cuda.empty_cache()

### Model 3 - MobileNet3D Trained from Scratch

In [None]:
# Model 3: baseline trained from scratch (no encoder checkpoint is loaded).
model_3 = MobileNetV3UNet3D()

# Training is expensive (up to num_epochs epochs), so it is skipped when a
# checkpoint from an earlier run already exists at model_path_3.
# Set this flag to True to force a fresh training run.
retrain_model_3 = False

if retrain_model_3 or not os.path.exists(model_path_3):
    train_UNet3D_weak_supervision(model=model_3,
                                  train_loader=train_loader,
                                  valid_loader=val_loader,
                                  device=device,
                                  num_epochs=num_epochs,
                                  lr=lr,
                                  log_path=log_path_3,
                                  model_path=model_path_3,
                                  patience=patience)
else:
    print(f"Checkpoint found at {model_path_3}; skipping training "
          f"(set retrain_model_3 = True to retrain).")

In [None]:
# Evaluate Model 3 on the test split and render qualitative overlays. Each helper
# is given model_path_3, presumably so it can load the saved checkpoint itself.
for run_on_test_set in (test_3d_unet, visualize_clip_with_overlay):
    run_on_test_set(model=model_3,
                    test_loader=test_loader,
                    model_path=model_path_3,
                    device=device)

# Export the overlay for the first clip of the first test batch as a GIF.
# NOTE: 'echocardigram' (sic) is kept so the output filename does not change.
save_overlay_gif_from_loader(model=model_3,
                             test_loader=test_loader,
                             model_path=model_path_3,
                             save_path='./assets/annotated_echocardigram_3.gif',
                             batch_idx=0,
                             sample_idx=0)

# Drop the notebook's reference to the model and return cached CUDA memory.
del model_3
torch.cuda.empty_cache()