In [None]:
import os 
import warnings 
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd

import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import optuna

from pathlib import Path
import xarray as xr

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [None]:
## Set the working directory to the project root and import the project modules.
# NOTE(review): the chdir presumably has to run before the `src` imports so that the
# package resolves from the project root -- confirm against the kernel's sys.path.
HOME_DIR = '/home/jovyan/open_pluto/kelp_forest_detection'
os.chdir(HOME_DIR)
from src.dataset import KelpDataset
from src.model import BaseUnet, UNetPlus, DiceLoss
from src.train import Trainer

In [None]:
# Notebook-wide configuration: paths, compute device and loader / model sizes.

# Data params: training root plus its satellite-image and kelp-mask sub-directories
TRAIN_DIR = os.path.join(HOME_DIR, 'data/training/')
IMAGE_DIR, MASK_DIR = (
    os.path.join(TRAIN_DIR, sub_dir) for sub_dir in ('train_satellite', 'train_kelp')
)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Model params: input / output channel counts
CHANNEL_IN, CHANNEL_OUT = 7, 1
# Loader params: samples per batch and a worker-related DataLoader setting
# (presumably the worker count -- confirm against where WORKER is consumed)
BATCH_SIZE = 4
WORKER = 2

## Get Data 

In [None]:
# Read the metadata file and derive the train / validation filename lists.
meta_df = pd.read_csv('metadata_kelp.csv')
# Keep only rows flagged `in_train`, capped at the first 2000 of them
df = meta_df.loc[meta_df['in_train'].eq(True)].head(2000)
# 80 / 20 split with a fixed random_state so the partition is repeatable
train_df, valid_df = train_test_split(df, random_state=42, test_size=.2)
train_files, valid_files = (
    subset['filename'].tolist() for subset in (train_df, valid_df)
)
# Display both list lengths as the cell output
(len(train_files), len(valid_files))

In [None]:
# Build the transform, the train / validation datasets and their DataLoaders.
torch.manual_seed(12)

# Only tensor conversion is applied; the same transform is shared by both datasets.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Custom datasets: one per split, selected by the filename lists built earlier
train_dataset = KelpDataset(image_dir=IMAGE_DIR, mask_dir=MASK_DIR,
                            transform=transform, filename_list=train_files)
valid_dataset = KelpDataset(image_dir=IMAGE_DIR, mask_dir=MASK_DIR,
                            transform=transform, filename_list=valid_files)

# Loading happens in the main process (0 workers), as before.
# `prefetch_factor` is only meaningful with worker processes: PyTorch 2.x raises
# ValueError when it is passed together with num_workers=0, so it is added to the
# keyword arguments only when workers are enabled.
NUM_LOADER_WORKERS = 0
loader_kwargs = {
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_LOADER_WORKERS,
    'pin_memory': True,
}
if NUM_LOADER_WORKERS > 0:
    loader_kwargs['prefetch_factor'] = 2

# Shuffle only the training split; validation keeps a fixed order
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Pull one batch from the training loader and display the image / mask shapes
_first_batch = next(iter(train_dataloader))
img, mask = _first_batch
(img.shape, mask.shape)

In [None]:
# Render the mask slice at index [3][0] of the batch
# (presumably sample 3, channel 0 -- confirm the mask layout)
plt.imshow(mask[3, 0])

In [None]:
# Histogram of the values in one mask slice.
# Note: despite its name, this variable holds the slice at index [3][0].
element_at_index_1_0 = mask[3, 0]

# Collapse the slice into a 1-D tensor
flattened_values = torch.flatten(element_at_index_1_0)

# Draw the histogram, then label and show it
plt.hist(flattened_values.numpy(), alpha=0.7, color='blue', bins=50)
plt.xlabel('Values')
plt.ylabel('Frequency')
plt.title('Distribution / Histogram Chart')
plt.show()

## Train Model

In [None]:
# Display how many CUDA devices PyTorch can see (cell output)
torch.cuda.device_count()

In [None]:
# Define the objective function for Optuna
def objective(trial):
    """Train one BaseUnet configuration and return its best validation loss.

    Hyperparameters sampled from ``trial``:
        dropout_prob: float in [0.0, 0.5].
        lr: float in [1e-5, 1e-2], sampled on a log scale.
        optimizer: 'Adam' or 'Adamax'.
        weight_decay: float in [1e-6, 1e-2], sampled on a log scale.
        loss: 'MSELoss', 'BCELoss' or 'HuberLoss'.

    Reads the notebook-level ``train_dataloader``, ``valid_dataloader``,
    ``CHANNEL_IN`` and ``CHANNEL_OUT``.

    Args:
        trial: the Optuna trial used to sample the hyperparameters; it is also
            handed to the Trainer.

    Returns:
        ``trainer.best_valid_loss`` after ``trainer.train(epochs=10)``.

    Raises:
        ValueError: if the sampled loss name has no matching criterion.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Model hyperparameters
    #features = [trial.suggest_int(f'features_{i}', 32, 512) for i in range(4)]
    dropout_prob = trial.suggest_float('dropout_prob', 0.0, 0.5)

    # Channel counts come from the notebook configuration instead of being repeated here
    model = BaseUnet(
        in_channels=CHANNEL_IN,
        out_channels=CHANNEL_OUT,
        #features=features,
        dropout_prob=dropout_prob
    )
    # Move the model to its device BEFORE building the optimizer, as the torch.optim
    # documentation recommends, so the optimizer is constructed over the final parameters.
    model = model.to(device)

    # Training hyperparameters
    lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'Adamax'])
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
    if optimizer_name == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    else:
        optimizer = torch.optim.Adamax(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Loss function.
    # NOTE(review): the returned objective is presumably computed with this same
    # criterion, so values from trials using different losses are on different scales
    # and are not directly comparable -- consider reporting one fixed validation metric.
    loss_name = trial.suggest_categorical('loss', ['MSELoss', 'BCELoss', 'HuberLoss'])
    if loss_name == 'MSELoss':
        criterion = torch.nn.MSELoss()
    elif loss_name == 'HuberLoss':
        criterion = torch.nn.HuberLoss()
    elif loss_name == 'BCELoss':
        # The 'BCELoss' choice uses the logits variant, which applies the sigmoid itself
        # (assumes the model outputs raw logits -- TODO confirm against BaseUnet)
        criterion = torch.nn.BCEWithLogitsLoss()
    else:
        raise ValueError(f'Unsupported loss: {loss_name}')

    # Instantiate Trainer with the Optuna trial
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        tb_path='tensorboard/runs/',
        checkpoint_path='tune/',
        trial=trial
    )

    # Train the model
    trainer.train(epochs=10)  # Adjust the number of epochs as needed

    # Scalar the study minimizes
    return trainer.best_valid_loss

In [None]:
# Create an optimization study and run the search.
# TPESampler is Optuna's default sampler; constructing it explicitly with a fixed seed
# makes the hyperparameter sampling repeatable across notebook runs. Training itself
# may still vary from run to run.
study = optuna.create_study(
    direction='minimize',  # Optimize for minimum validation loss
    sampler=optuna.samplers.TPESampler(seed=42),
)
study.optimize(objective, n_trials=10)