In [2]:
import sys
sys.path.append('..')

import torch
from torch.utils.data import DataLoader

from src.train import CombinedLoss, train, MedicalScanDataset
from src.models.transunet.TransAttUnet import UNet_Attention_Transformer_Multiscale
from src.models.unet import UNet
from src.evaluate import evaluate_model

# Hyperparameters shared by the data, training and evaluation cells.
NUM_CLASSES = 8      # number of output classes passed to the models and CombinedLoss
BATCH_SIZE = 16      # batch size used by every DataLoader
EPOCH_COUNT = 15     # number of training epochs passed to train()
INITIAL_LR = 1e-4    # Adam learning rate

# Seed torch's RNG (all devices) so weight initialisation, DataLoader shuffling and
# any torch-based randomness in training (presumably including the p_adversial
# sampling - confirm in src.train) repeat across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

### Initialize data

In [None]:
# Use the GPU when torch can see one; otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One directory of preprocessed .npz files per split under ct_256/.
data_dir = "pack/processed_data"
train_data, valid_data, test_data = (
    MedicalScanDataset(f'{data_dir}/ct_256/{split}/npz/')
    for split in ('train', 'val', 'test')
)


def _split_loader(dataset, shuffle):
    """Wrap one dataset split in a single-process DataLoader using BATCH_SIZE."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=0)


# Only the training split is shuffled; validation/test order stays fixed.
train_loader = _split_loader(train_data, shuffle=True)
valid_loader = _split_loader(valid_data, shuffle=False)
test_loader = _split_loader(test_data, shuffle=False)

### Load and train model

In [None]:
# Build the single-input-channel attention/transformer U-Net and place it on `device`
# (nn.Module.to moves parameters in place and returns the same module).
model = UNet_Attention_Transformer_Multiscale(
    1, NUM_CLASSES, output_attention=False
).to(device)

# Combined loss + Adam, then hand everything to the project training loop,
# which returns the model that is kept for later cells.
loss_function = CombinedLoss(n_classes=NUM_CLASSES)
optimizer = torch.optim.Adam(model.parameters(), lr=INITIAL_LR)
model = train(
    model, loss_function, optimizer,
    train_loader, valid_loader, test_loader,
    EPOCH_COUNT, NUM_CLASSES,
    p_adversial=0.3,          # keyword spelling follows src.train.train's signature
    output_attention=False,   # must agree with the constructor flag above
)

### Test model

In [None]:
# Test model: load a saved plain-UNet checkpoint and evaluate it on the test split
# under a DAG adversarial attack.
# The model is rebuilt here (previously a typo, `odel = ...`, left the trained
# attention model from the cell above in `model`, so the checkpoint was loaded into
# the wrong architecture). This cell therefore no longer depends on the training
# cell's `model` object.
# NOTE(review): confirm weights/model_weights.pth holds a `UNet` state_dict - the
# training cell above trains UNet_Attention_Transformer_Multiscale instead.
model = UNet(input_channels=1, num_classes=NUM_CLASSES).to(device)
wpth = 'weights/model_weights.pth'
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# also loads on a CPU-only machine.
model.load_state_dict(torch.load(wpth, map_location=device))
model.eval()

loss_function = CombinedLoss(n_classes=NUM_CLASSES)
# No torch.no_grad() here: a gradient-based attack such as DAG presumably needs
# input gradients inside evaluate_model - confirm in src.evaluate.
test_dice, test_loss = evaluate_model(
    model,
    test_loader,
    loss_function,
    NUM_CLASSES,
    adversarial='DAG',
    epsilon=0.01,        # presumably the perturbation budget - confirm in src.evaluate
    adv_iterations=5,
)
print(f'Test Loss: {test_loss:.4f} | Test Dice: {test_dice:.4f}')