# Importing dependencies and initialization

In [1]:
import os
from dotenv import load_dotenv

# Load variables from a local .env file into os.environ; the R2_* settings read
# in the archiving cell are looked up via os.environ.
load_dotenv()

# Make CUDA kernel launches synchronous so errors are reported at the failing call.
# NOTE(review): this is a debugging aid and usually slows GPU training noticeably;
# it is set before `import torch` below — keep that ordering if this line stays.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import warnings
# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides library warnings (e.g. DataLoader worker-count
# warnings) that may point at real problems; consider a narrower filter.
warnings.filterwarnings("ignore")

import gdown  # Google Drive downloader used to fetch the dataset archive
import patoolib  # archive extraction

from preprocessing import get_data, get_transforms  # project-local data helpers
from monai.data import Dataset, DataLoader

import torch
from unet import FlexUNet, train_epoch, validate_epoch  # project-local model and train/val loops
from metric import AllDiceMetric, SeparateDiceMetric  # project-local Dice metrics
from monai.losses import DiceLoss

import boto3
from botocore.client import Config
from tqdm import tqdm

# Use the current CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")


# Loading the dataset

In [None]:
# Fetch and unpack the dataset only when the extracted `data` directory is missing.
if not os.path.exists('data'):
    # Destination path where the archive is saved before extraction
    destination = 'archive.zip'

    # Reuse an archive left behind by an earlier (e.g. interrupted) run instead of
    # downloading again; either way it is extracted below.
    if not os.path.exists(destination):
        # Google Drive file ID of the dataset archive
        file_id = '1bz476ATbSduGcyw1UIkOkN5YkYZvurZ5'

        # Construct the download URL
        url = f'https://drive.google.com/uc?id={file_id}'

        # Some gdown versions signal failure by returning None instead of raising.
        if gdown.download(url, destination, quiet=False) is None:
            raise RuntimeError(f'Failed to download dataset archive from {url}')

    # Unpack into ./data, then delete the archive to free disk space.
    patoolib.extract_archive(destination, outdir='data')
    os.remove(destination)
    

# Data preprocessing

In [None]:
# Build the train/validation file lists and the MONAI transform pipeline shared by both.
train_files, val_files = get_data()
transforms = get_transforms()

print(f'Train files: {len(train_files)}')
# val_files is the split evaluated every epoch during training, i.e. validation data,
# so it is labelled as such rather than as a held-out test set.
print(f'Validation files: {len(val_files)}')

In [None]:
# Wrap each split's file list in a MONAI Dataset; both apply the same transforms.
train_ds, val_ds = (
    Dataset(data=split_files, transform=transforms)
    for split_files in (train_files, val_files)
)

## Model training

In [None]:
# Training hyperparameters
EPOCHS = 100
BATCH_SIZE = 32
LR = 1e-4
# DataLoader worker processes: up to 16, but never more than the CPUs available.
# os.cpu_count() may return None, hence the fallback to 1. Any DataLoader warning
# about too many workers would be hidden by the global warnings filter.
WORKERS = min(16, os.cpu_count() or 1)

# NOTE(review): model_size=64 presumably sets the base channel width — confirm in unet.py.
model = FlexUNet(model_size=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Shuffle only the training split; validation order stays fixed between epochs.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=WORKERS)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=WORKERS)

# NOTE(review): loss_fn is never passed to train_epoch in the training loop; confirm
# whether train_epoch uses its own loss or this object is meant to be passed in.
loss_fn = DiceLoss()

all_dice_metric = AllDiceMetric(include_background=True)
separate_dice_metric = SeparateDiceMetric()

# Checkpoint name; the archiving cell uses it as both local path and R2 object key.
model_id = 'UNET_64_Teacher_01'
file_id = model_id + '.pth'


In [None]:
for epoch in range(EPOCHS):
    loss = train_epoch(model, train_loader, optimizer, device)
    print(f'Epoch {epoch+1}/{EPOCHS}, Dice Loss: {loss}')
    # NOTE(review): the two metric objects are passed twice here; confirm this
    # matches validate_epoch's signature in unet.py.
    total, tumor, enhancing, core = validate_epoch(model, val_loader, all_dice_metric, separate_dice_metric, device, all_dice_metric, separate_dice_metric)
    print(f'Total Dice: {total}, Tumor Dice: {tumor}, Enhancing Dice: {enhancing}, Core Dice: {core}')
    print('-----------------------------------')

# Save under file_id (derived from model_id) so the archiving cell, which uploads
# file_id, always uploads exactly this checkpoint even if model_id changes.
torch.save(model.state_dict(), file_id)

## Model archiving

In [None]:
# R2 credentials, endpoint and bucket, read from the environment (a .env file is
# loaded into os.environ by load_dotenv at startup).
access_key = os.environ.get('R2_KEY')
secret_key = os.environ.get('R2_SECRET')
endpoint_url = os.environ.get('R2_ENDPOINT')

bucket_name = os.environ.get('R2_BUCKET')

# Fail fast with a clear message. With endpoint_url=None boto3 targets the default
# AWS S3 endpoint, and None credentials fall back to boto3's own credential chain,
# so a missing variable could otherwise send the upload somewhere unintended.
_missing_r2_vars = [
    name
    for name, value in (
        ('R2_KEY', access_key),
        ('R2_SECRET', secret_key),
        ('R2_ENDPOINT', endpoint_url),
        ('R2_BUCKET', bucket_name),
    )
    if not value
]
if _missing_r2_vars:
    raise RuntimeError(f"Missing R2 environment variables: {', '.join(_missing_r2_vars)}")

# S3-compatible client pointed at the R2 endpoint, using SigV4 request signing.
s3 = boto3.client(
    's3',
    endpoint_url=endpoint_url,
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_key,
    config=Config(signature_version='s3v4'),
)

# Function to upload file with progress bar
def upload_file_with_progress(file_path, bucket_name, object_name, client=None):
    """Upload a local file to an S3-compatible bucket while showing a tqdm progress bar.

    Args:
        file_path: Path of the local file to upload.
        bucket_name: Name of the destination bucket.
        object_name: Object key to store the file under.
        client: boto3 S3 client to use; defaults to the module-level ``s3`` client.

    Raises:
        OSError: if ``file_path`` cannot be stat'ed (e.g. it does not exist).
        Any exception from ``client.upload_file`` propagates unchanged; the
        progress bar is closed in either case.
    """
    if client is None:
        client = s3
    file_size = os.path.getsize(file_path)
    # The `with` block guarantees the bar is closed even if the upload raises.
    with tqdm(total=file_size, unit='B', unit_scale=True, desc='Uploading ' + object_name) as progress:
        # boto3 invokes Callback with the bytes transferred since its previous call,
        # which is exactly the increment tqdm.update expects.
        client.upload_file(file_path, bucket_name, object_name, Callback=progress.update)

# Upload the trained checkpoint; its R2 object key is identical to its local file name.
object_name = file_path = file_id
upload_file_with_progress(file_path, bucket_name, object_name)