In [None]:
import torch
from torch import nn
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
import models
import training
from torchvision.transforms import v2
import numpy as np
import matplotlib.pyplot as plts
from os.path import join
import torchvision
import torch.nn.functional as F
import h5py
import matplotlib.pyplot as plt
from torch.utils.data import random_split

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

def format_data(x):
    """Normalize each channel of an image stack, resize it to 64x64 and move it to ``device``.

    Parameters
    ----------
    x : numpy.ndarray
        Image stack indexed as ``x[sample, channel, row, col]`` (the indexing
        below requires rank 4).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``(N, C, 64, 64)`` on the module-level ``device``.
        Each channel is standardized with that channel's mean and (population)
        std taken over all samples and pixels of ``x``.
    """
    # Per-channel statistics over every sample and pixel, computed on the raw input.
    mean = [float(x[:, n, :, :].mean()) for n in range(x.shape[1])]
    std = [float(x[:, n, :, :].std()) for n in range(x.shape[1])]

    # v2.Normalize rejects integer tensors, and float64 images would not match
    # the float32 model weights, so cast before normalizing.
    tensor = torch.from_numpy(x).float()

    X = v2.Compose([
            v2.Normalize(mean=mean, std=std),
            v2.Resize((64, 64)),
        ])(tensor).to(device)
    return X

In [None]:
# Train one regressor per order on the simulated training set, logging to
# TensorBoard and checkpointing via training.EarlyStopping.
SEED = 0  # fixes the train/validation split so runs are comparable
orders = range(5,6)

for order in orders:
    with h5py.File('../../Data/Training/mixed_intense.h5', 'r') as f:
        images = format_data(f[f'images_order{order}'][:])
        # MSELoss requires the target dtype to match the float32 model output.
        labels = torch.from_numpy(f[f'labels_order{order}'][:]).float().to(device)

    dset = TensorDataset(images, labels)

    # 85/15 train/validation split, seeded so it is identical across re-runs.
    train_size = int(0.85 * len(dset))
    test_size = len(dset) - train_size

    train_dataset, test_dataset = random_split(
        dset, [train_size, test_size], generator=torch.Generator().manual_seed(SEED)
    )
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64)

    L = images.shape[2]
    loss_fn = torch.nn.MSELoss()
    n_channels = 2
    n_classes = (order+1)**2-1
    # MSELoss silently broadcasts mismatched shapes, so check the label width up front.
    assert labels.shape[-1] == n_classes, (
        f"labels have {labels.shape[-1]} columns, expected {n_classes} for order {order}"
    )

    #model = models.ConvNet(L,L,n_channels, n_classes,[24,40,35],5,nn.ELU,[120,80,40]).to(device)
    # Create a mobilenet_v3_small model whose stem convolution accepts n_channels
    # inputs (the stock network expects 3) and whose head outputs n_classes values.
    model = torchvision.models.mobilenet_v3_small(num_classes=n_classes)
    stem = model.features[0][0]
    model.features[0][0] = nn.Conv2d(
        n_channels,
        stem.out_channels,
        kernel_size=stem.kernel_size,
        stride=stem.stride,
        padding=stem.padding,
        bias=False,
    )
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=.5, min_lr = 1e-5)

    save_path = f"../../Results/MachineLearningModels/Intense/MobileNet/Mixed_Order{order}"

    writer = SummaryWriter(save_path)
    early_stopping = training.EarlyStopping(patience=30,save_path=save_path)
    for t in range(200):
        epoch = t+1
        print(f"-------------------------------\nEpoch {epoch}")
        training.train(model, train_loader, loss_fn, optimizer, device)
        val_loss = training.test(model, test_loader, loss_fn, device, epoch, writer, verbose=True)
        scheduler.step(val_loss)
        early_stopping(val_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break
    print("Done!")
    writer.close()
        

In [None]:
# Learning rate(s) the scheduler reports after the most recent training run.
last_lrs = scheduler.get_last_lr()
last_lrs

In [None]:
def crop_indices(width, height, xmin, xmax, ymin, ymax, Xmin, Xmax, Ymin, Ymax):
    """Convert a coordinate window into pixel slice bounds for an image.

    The ``height`` x ``width`` image is taken to span ``[xmin, xmax]``
    horizontally and ``[ymin, ymax]`` vertically. The arithmetic maps ``xmin`` to
    column 0 and ``ymin`` to row 0. The requested window is
    ``[Xmin, Xmax] x [Ymin, Ymax]``.

    Returns
    -------
    (upper, lower, left, right) : tuple of int
        Half-open slice bounds, so ``img[upper:lower, left:right]`` crops the
        window. Bounds are clamped to ``[0, height]`` and ``[0, width]``.
    """
    def to_pixel(value, lo, hi, size):
        # Linear map from the coordinate range [lo, hi] onto the pixel range [0, size].
        return int(round(size * (value - lo) / (hi - lo)))

    upper = to_pixel(Ymin, ymin, ymax, height)
    lower = to_pixel(Ymax, ymin, ymax, height)
    left = to_pixel(Xmin, xmin, xmax, width)
    right = to_pixel(Xmax, xmin, xmax, width)

    # Slices are 0-based and half-open, so the valid range is [0, size]; a lower
    # clamp of 1 would silently drop row/column 0 when the window overhangs.
    return max(upper, 0), min(lower, height), max(left, 0), min(right, width)

In [None]:
# Load the processed experimental data: axis limits for the direct and converted
# channels, plus the order-1 image stack used by the crop preview.
# Open read-only explicitly (h5py < 3 defaulted to 'a', which can create/modify files).
with h5py.File('../../Data/Processed/mixed_intense.h5', 'r') as f:
    direct_lims = f['direct_lims'][:]
    converted_lims = f['converted_lims'][:]
    images = f['images_order1'][:]

In [None]:
# Preview: crop channel 0 of sample 3 to the [-3, 3] x [-3, 3] window and show it.
window = (-3, 3, -3, 3)
upper_d, lower_d, left_d, right_d = crop_indices(400, 400, *direct_lims, *window)

sample_index, channel = 3, 0
plt.imshow(images[sample_index, channel, upper_d:lower_d, left_d:right_d])

In [None]:
# Run each order's trained model on the cropped experimental images and store the
# predictions next to the reference labels.
orders = range(1,6)

# Resize each cropped channel to the 64x64 network input size.
transform = v2.Compose([
    torch.from_numpy,
    v2.Resize((64, 64))])

def normalize(x):
    """Standardize each channel of ``x`` by its own mean and std over all samples and pixels.

    Assumes ``x`` is a 4-D tensor indexed ``x[sample, channel, row, col]``.
    """
    mean = [x[:, n, :, :].mean() for n in range(x.shape[1])]
    std = [x[:, n, :, :].std() for n in range(x.shape[1])]
    return v2.Normalize(mean=mean, std=std)(x)

with h5py.File('../../Data/Processed/mixed_intense.h5', 'r') as f:
    direct_lims = f['direct_lims'][:]
    converted_lims = f['converted_lims'][:]

    with h5py.File("../../Results/Intense/machine_learning.h5", 'w') as out:

        for order in orders:
            # Half-width of the square crop window; it grows with the order.
            R = 2.6 + 0.6*order
            upper_d,lower_d,left_d,right_d = crop_indices(400,400, *direct_lims, -R,R,-R,R)
            upper_c,lower_c,left_c,right_c = crop_indices(400,400, *converted_lims, -R,R,-R,R)

            direct = transform(f[f'images_order{order}'][:,0,upper_d:lower_d,left_d:right_d]).float()
            converted = transform(f[f'images_order{order}'][:,1,upper_c:lower_c,left_c:right_c]).float()

            images_exp = normalize(torch.stack((direct,converted),1)).to(device)
            labels_exp = f[f'labels_order{order}'][:]

            # NOTE(review): the training cell saves to .../Intense/MobileNet/Mixed_Order{order};
            # this path has no "MobileNet" component - confirm which checkpoint is intended.
            # The checkpoint is called as a model below, so it is a pickled module:
            # weights_only=False is required on PyTorch >= 2.6 (only load trusted files).
            # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
            model = torch.load(
                f"../../Results/MachineLearningModels/Intense/Mixed_Order{order}/checkpoint.pt",
                map_location=device,
                weights_only=False,
            )
            # no_grad does not change layer behavior; eval() makes any BatchNorm/Dropout
            # layers run in inference mode.
            model.eval()

            with torch.no_grad():
                labels_pred = model(images_exp).cpu().numpy()

            out.create_dataset(f'pred_labels_order{order}', data=labels_pred)
            out.create_dataset(f'labels_order{order}', data=labels_exp)