# Importing dependencies and initialization

In [1]:
import os
import json
from dotenv import load_dotenv

load_dotenv()

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import warnings
warnings.filterwarnings("ignore")

import gdown
import patoolib

from punkd.preprocessing import get_data, get_transforms
from monai.data import Dataset, DataLoader

import torch
from punkd.unet import FlexUNet, train_epoch, validate_epoch
from monai.losses import DiceLoss
from monai.metrics import DiceMetric

import pandas as pd

import boto3
from botocore.client import Config
from tqdm.notebook import tqdm

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# A torch.device never compares equal to a plain string, so the device kind has
# to be read from `.type`; comparing the object itself with 'cuda' is always False.
if device.type == 'cuda':
    # Hand unused cached GPU memory back to the driver before training starts.
    torch.cuda.empty_cache()

print(f"Device: {device}")


Device: cuda


# Loading the configuration

In [2]:
# Read the run settings from config.json, looked up relative to the working directory.
# `config` is pre-set to None so the name exists even if reading the file fails.
config = None
with open('config.json', 'r') as config_fh:
    config = json.loads(config_fh.read())

{'EPOCHS': 100, 'MODEL_SIZE': 64, 'TRAIN_BATCH_SIZE': 8, 'VAL_BATCH_SIZE': 2, 'CPU_WORKERS': 16, 'LEARNING_RATE': 0.0001}


# Loading the dataset

In [None]:
# Fetch and unpack the dataset archive unless a local copy is already present.
ARCHIVE_PATH = 'archive.zip'
DATA_DIR = 'data'

if not os.path.exists(ARCHIVE_PATH) and not os.path.exists(DATA_DIR):
    # Google Drive file ID
    file_id = '1bz476ATbSduGcyw1UIkOkN5YkYZvurZ5'
    # Construct the download URL
    url = f'https://drive.google.com/uc?id={file_id}'

    # Some gdown releases report a failed download by returning None rather than
    # raising (TODO confirm for the pinned version); stop here with a clear message
    # instead of failing later on a missing archive.
    if gdown.download(url, ARCHIVE_PATH, quiet=False) is None:
        raise RuntimeError(f'Download of {url} failed; {ARCHIVE_PATH} was not created')

# Single extraction path: covers both a fresh download and an archive that was
# already sitting in the working directory. The archive is deleted afterwards.
if os.path.exists(ARCHIVE_PATH):
    patoolib.extract_archive(ARCHIVE_PATH, outdir=DATA_DIR)
    os.remove(ARCHIVE_PATH)

# Training case 355 carries a differently named segmentation file; rename it
# (presumably so it matches the `<case>_seg.nii` pattern the loader expects — verify in get_data).
case_dir = 'data/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_355'
misnamed_seg = f'{case_dir}/W39_1998.09.19_Segm.nii'
expected_seg = f'{case_dir}/BraTS20_Training_355_seg.nii'
if os.path.exists(misnamed_seg):
    os.rename(misnamed_seg, expected_seg)
    

# Data preprocessing

In [None]:
# Build the train/validation file lists and the transform pipeline
# (both helpers come from punkd.preprocessing).
train_files, val_files = get_data()
transforms = get_transforms()

print(f'Train files: {len(train_files)}')
# This split feeds val_ds / val_loader, so it is reported as validation data, not test data.
print(f'Validation files: {len(val_files)}')

In [None]:
# Wrap each file list in a MONAI Dataset; both splits share the same transform pipeline.
train_ds, val_ds = (
    Dataset(data=file_list, transform=transforms)
    for file_list in (train_files, val_files)
)

## Model training

In [None]:
# Hyperparameters, all taken from the loaded config.json
EPOCHS, BATCH_SIZE_TRAIN, BATCH_SIZE_VAL, LR, WORKERS, MODEL_SIZE = (
    config[key]
    for key in (
        'EPOCHS',
        'TRAIN_BATCH_SIZE',
        'VAL_BATCH_SIZE',
        'LEARNING_RATE',
        'CPU_WORKERS',
        'MODEL_SIZE',
    )
)

# Output artifacts are named after the model identifier
model_id = 'UNET_4_Teacher_01'
model_file = model_id + '.pth'
csv_file = model_id + '.csv'

# Network and optimizer
model = FlexUNet(model_size=MODEL_SIZE).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Only the training loader shuffles; validation keeps the file order
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE_TRAIN,
    shuffle=True,
    num_workers=WORKERS,
)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE_VAL,
    shuffle=False,
    num_workers=WORKERS,
)

# Dice loss drives training; Dice metric (reduced with 'mean_batch') scores validation
loss_fn = DiceLoss().to(device)
score_fn = DiceMetric(reduction='mean_batch')

# Empty per-epoch metric log, filled in by the training loop
performance = pd.DataFrame(
    columns=[
        'epoch',
        'train_loss',
        'train_dice',
        'tumor_dice',
        'enhancing_dice',
        'core_dice',
    ]
)

In [None]:
# Train for EPOCHS epochs, validating after each one, then persist the final
# weights and the per-epoch metric log.
PERFORMANCE_COLUMNS = ['epoch', 'train_loss', 'train_dice', 'tumor_dice', 'enhancing_dice', 'core_dice']

# One dict per finished epoch. The DataFrame is built once after the loop instead of
# being re-concatenated every epoch, and re-running this cell starts from an empty log.
history = []

with tqdm(total=EPOCHS, unit='epoch', position=0) as pbar:
    pbar.set_description('EPOCH 0')
    for epoch in range(EPOCHS):
        loss = train_epoch(model, train_loader, loss_fn, optimizer, device)
        # validate_epoch yields an overall score plus four per-class scores;
        # the background score is unpacked but not logged.
        total, (background, tumor, enhancing, core) = validate_epoch(model, val_loader, score_fn, device)
        pbar.set_description(f'EPOCH {epoch+1}')
        pbar.set_postfix({'Dice Loss': loss, 'Total Score': total, 'Tumor Score': tumor, 'Enhancing Score': enhancing, 'Core Score': core})
        # NOTE(review): 'train_dice' stores the overall score computed on val_loader;
        # the column name is kept as-is so existing CSV consumers keep working.
        history.append({'epoch': epoch+1, 'train_loss': loss, 'train_dice': total, 'tumor_dice': tumor,
                        'enhancing_dice': enhancing, 'core_dice': core})
        pbar.update()

# Passing `columns` fixes the column order and keeps the header even when EPOCHS is 0.
performance = pd.DataFrame(history, columns=PERFORMANCE_COLUMNS)

torch.save(model.state_dict(), model_file)
performance.to_csv(csv_file, index=False)

## Model archiving

In [None]:
# R2 connection settings, read from the process environment (None when a variable is unset)
endpoint_url = os.environ.get('R2_ENDPOINT')
access_key = os.environ.get('R2_KEY')
secret_key = os.environ.get('R2_SECRET')

# Target bucket for the uploaded artifacts
bucket_name = os.environ.get('R2_BUCKET')

# S3 client pointed at the R2 endpoint, signing requests with signature version 's3v4'
s3_settings = {
    'endpoint_url': endpoint_url,
    'aws_access_key_id': access_key,
    'aws_secret_access_key': secret_key,
    'config': Config(signature_version='s3v4'),
}
s3 = boto3.client('s3', **s3_settings)

def upload_file_with_progress(file_path, bucket_name, object_name):
    """Upload a local file through the module-level `s3` client with a byte progress bar.

    Args:
        file_path: Path of the local file to upload.
        bucket_name: Name of the destination bucket.
        object_name: Key to store the file under; also shown in the bar's description.

    Raises:
        OSError: If `file_path` cannot be stat'ed.
        Exception: Whatever `s3.upload_file` raises is propagated unchanged.
    """
    file_size = os.path.getsize(file_path)
    # Using the bar as a context manager closes it even when the upload raises,
    # so a failed transfer does not leave a dangling progress widget behind.
    with tqdm(total=file_size, unit='B', unit_scale=True, desc='Uploading ' + object_name) as progress:
        # The callback receives the number of bytes transferred since its previous call.
        s3.upload_file(file_path, bucket_name, object_name, Callback=progress.update)


# Push the trained weights first, then the metric log; each is stored under its own file name
for artifact in (model_file, csv_file):
    upload_file_with_progress(artifact, bucket_name, artifact)
