In [None]:
## Load the libraries:

from torch.utils.data import Dataset, DataLoader
from model_creation_torch import *
from tensorflow.keras.callbacks import EarlyStopping
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Data Creation

In [None]:
# Create the data:
# Regenerate the datasets from scratch: first remove everything under
# data/spd2spd/, then run the generator script once per configuration.
# NOTE(review): the meaning of the four positional arguments is defined in
# make_data_spd2spd.py, which is not visible here. The script presumably writes
# the data/spd2spd/*_<context>.npy files that the training cell loads -- confirm.

!rm data/spd2spd/*
!python make_data_spd2spd.py  10000 5 20 10
!python make_data_spd2spd.py  10000 5 40 20
!python make_data_spd2spd.py  10000 20 40 20


# Training

In [18]:
# Delete old models and training histories before retraining.
# NOTE: these two commands remove every file in both directories.

!rm trained_models/*
!rm training_history/*

# Hyperparameters

# Training-set sizes: each run trains on the first n samples of the loaded data.
ns = [10000, 1000, 100] 
# ns = [100000]
# Dataset identifiers: each one selects data/spd2spd/{X,Y,Xt,Yt}_<context>.npy, and
# its fourth "_"-separated field is parsed as the model rank by the training loop.
contexts = ['10000_20_10_5_0', '10000_40_20_5_0', '10000_40_20_20_0']
# contexts = ['10000_20_10_5_0']
# Mini-batch size used for both the training and the test DataLoader.
batch_size = 32
# Upper bound on the number of epochs per model; early stopping may end a run sooner.
epochs = 100
# Learning rate handed to the Adam optimizer.
lr = 0.001
# Set to True to skip the whole training loop.
skip_training = False


# Custom Dataset

class CustomDataset(Dataset):
    """Map-style dataset that pairs two index-aligned containers.

    Item ``i`` is the tuple ``(X[i], Y[i])``. The reported length is the
    length of ``X``; ``Y`` is not checked against it.
    """

    def __init__(self, X, Y):
        self.X, self.Y = X, Y

    def __len__(self):
        # Only X determines the dataset size.
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]
    
# Early stopping:

class EarlyStopping:
    """Patience-based stopping criterion driven by a scalar loss.

    Call the instance once per epoch with the latest loss. An epoch counts as
    an improvement when the loss is at least ``min_delta`` below the best loss
    seen so far; it then becomes the new reference and the patience counter is
    reset. Otherwise the counter grows, and once it reaches ``tolerance``
    consecutive non-improving epochs ``early_stop`` is set to True (it is never
    reset afterwards).

    A NaN loss is always treated as a non-improving epoch. Without that rule
    every comparison against NaN is False, NaN would be stored as the reference
    loss, and a diverged run would never trigger early stopping.

    Note: this class is defined in the notebook itself, so it takes precedence
    over the Keras ``EarlyStopping`` imported in the first cell.

    Args:
        tolerance: number of consecutive non-improving epochs before stopping.
        min_delta: minimum decrease of the loss that counts as an improvement.

    Attributes:
        counter: current number of consecutive non-improving epochs.
        early_stop: True once training should stop.
        previous_loss: best (lowest) loss accepted so far.
    """

    def __init__(self, tolerance=5, min_delta = 0.0001):
        self.min_delta = min_delta
        self.tolerance = tolerance
        self.counter = 0
        self.early_stop = False
        self.previous_loss = float("inf")

    def __call__(self, train_loss):
        # NaN is the only value that differs from itself; bool() also accepts
        # a 0-dim tensor here.
        is_nan = bool(train_loss != train_loss)
        stalled = is_nan or bool(self.previous_loss < train_loss + self.min_delta)

        if stalled:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True
        else:
            self.previous_loss = train_loss
            self.counter = 0
    
# Training and testing loop:

def tr_train_loop(dataloader, model, loss_fn, optimizer, reg_weight=0.001):
    """Run one training epoch for the trace regression model.

    The optimised objective is ``loss_fn(pred, y)`` plus ``reg_weight`` times
    the sum of the singular values of ``model.weight``. Only the unpenalised
    ``loss_fn`` value is logged and averaged.

    Args:
        dataloader: iterable of ``(X, y)`` batches; must expose ``.dataset``.
        model: callable module exposing a ``weight`` tensor accepted by
            ``torch.linalg.svd``.
        loss_fn: callable returning a scalar tensor from ``(pred, y)``.
        optimizer: optimizer updating the parameters of ``model``.
        reg_weight: multiplier of the singular-value penalty (default 0.001).

    Returns:
        float: mean over batches of the unpenalised loss. A plain Python
        float, matching what ``train_loop`` returns, so the value can be
        stored in the history table and compared by early stopping directly.
    """
    num_batches = len(dataloader)
    size = len(dataloader.dataset)
    train_loss = 0.0

    for batch, (X, y) in enumerate(dataloader):
        # Compute the prediction and the data-fit loss once; the same tensor
        # is used for both the objective and the logged value.
        pred = model(X)
        sq_loss = loss_fn(pred, y)

        # Singular values of the weight; torch.linalg.svd also accepts
        # leading batch dimensions (*, m, n).
        _, S, _ = torch.linalg.svd(model.weight)
        loss = sq_loss + S.sum() * reg_weight

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() detaches the value, so no autograd graph is kept alive.
        batch_sq_loss = sq_loss.item()
        train_loss += batch_sq_loss

        if batch % 100 == 0:
            current = (batch + 1) * len(X)
            print(f"loss: {batch_sq_loss:>7f}  [{current:>5d}/{size:>5d}]")

    return train_loss / num_batches

def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Every batch gets a forward pass, a backward pass and one optimizer step.
    Progress is printed for every 100th batch.

    Returns:
        float: the mean of the per-batch loss values.
    """
    n_batches, n_samples = len(dataloader), len(dataloader.dataset)
    running = 0.0

    for step, (inputs, targets) in enumerate(dataloader):
        # Forward pass and loss
        batch_loss = loss_fn(model(inputs), targets)
        running += batch_loss.item()

        # Backward pass and parameter update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if step % 100 != 0:
            continue
        seen = (step + 1) * len(inputs)
        print(f"loss: {batch_loss.item():>7f}  [{seen:>5d}/{n_samples:>5d}]")

    return running / n_batches


def test_loop(dataloader, model, loss_fn):
    num_batches = len(dataloader)
    test_loss=  0.0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()

    test_loss /= num_batches
    print(f"Test Error: \n  Avg loss: {test_loss:>8f} \n")
    return test_loss

# Utility functions for saving:

def create_history_df():
    """Return an empty history table with one column per tracked quantity."""
    column_names = ['Epochs', 'TrainLoss', 'TestLoss']
    return pd.DataFrame(columns=column_names)

def add_record_to_dataframe(df, record):
    """Return a new DataFrame made of ``df`` followed by ``record`` as one row.

    ``df`` itself is not modified.

    Args:
        df: history table; its columns define the layout of ``record``.
        record: sequence with one value per column of ``df``.

    Returns:
        pandas.DataFrame with a fresh 0..n-1 index.
    """
    # Convert the record tuple to a DataFrame with a single row
    new_row = pd.DataFrame([record], columns=df.columns)

    # Concatenating an empty frame is deprecated in recent pandas (it emits a
    # FutureWarning and its effect on the result dtypes is changing). For an
    # empty table the new row already is the complete result.
    if df.empty:
        return new_row

    # Concatenate the new DataFrame with the original DataFrame
    return pd.concat([df, new_row], ignore_index=True)

def save_dataframe_to_pickle(df, n, context, name):
    """Pickle ``df`` to ``training_history/{name}_{context}_{n}.pickle``.

    The target directory is created when it is missing, so a finished
    training run is not lost to a missing folder.

    Args:
        df: DataFrame to store.
        n: value used as the last part of the file name.
        context: value used as the middle part of the file name.
        name: value used as the first part of the file name.

    Returns:
        str: the relative path that was written.
    """
    import os

    filepath = f'training_history/{name}_{context}_{n}.pickle'
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    df.to_pickle(filepath)
    return filepath
    

def save_model(model, n, context, name):
    """Save ``model.state_dict()`` to ``trained_models/{name}_{context}_{n}.pt``.

    The target directory is created when it is missing, so a finished
    training run is not lost to a missing folder.

    Args:
        model: object exposing ``state_dict()``.
        n: value used as the last part of the file name.
        context: value used as the middle part of the file name.
        name: value used as the first part of the file name.

    Returns:
        str: the relative path that was written.
    """
    import os

    filepath = f'trained_models/{name}_{context}_{n}.pt'
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    torch.save(model.state_dict(), filepath)
    return filepath



# Loop

def _fit_and_save(name, model, train_fn, earlystop, train_dataloader, test_dataloader, n, context):
    """Train ``model`` with early stopping, then save its weights and loss history.

    Uses an MSE loss and an Adam optimizer with the notebook-level ``lr``,
    for at most ``epochs`` epochs. Early stopping is driven by the training
    loss returned by ``train_fn``.

    Args:
        name: tag used in the names of the saved model and history files.
        model: freshly constructed model to train (updated in place).
        train_fn: one-epoch training function with the signature
            ``(dataloader, model, loss_fn, optimizer)`` returning the epoch loss.
        earlystop: fresh ``EarlyStopping`` instance for this run.
        train_dataloader: DataLoader over the training samples.
        test_dataloader: DataLoader over the test samples.
        n: training-set size, used in the saved file names.
        context: dataset identifier, used in the saved file names.

    Returns:
        The history table with one (epoch, train loss, test loss) row per epoch run.
    """
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = create_history_df()

    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loss = train_fn(train_dataloader, model, loss_fn, optimizer)
        test_loss = test_loop(test_dataloader, model, loss_fn)
        history = add_record_to_dataframe(history, (t, train_loss, test_loss))
        earlystop(train_loss)
        if earlystop.early_stop:
            break

    save_model(model, n, context, name)
    save_dataframe_to_pickle(history, n, context, name)
    return history


if not skip_training:
    for n in ns:
        for context in contexts:
            # The fourth "_"-separated field of the context string is the rank.
            rank = int(context.split('_')[3])
            X = np.load('data/spd2spd/X_'+context+'.npy')
            Y = np.load('data/spd2spd/Y_'+context+'.npy')
            Xt = np.load('data/spd2spd/Xt_'+context+'.npy')
            Yt = np.load('data/spd2spd/Yt_'+context+'.npy')

            # Create training dataset and dataloader (first n samples only)
            train_dataset = CustomDataset(X[0:n], Y[0:n])
            train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

            # Create test dataset and dataloader
            test_dataset = CustomDataset(Xt, Yt)
            test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

            # The four models are built and trained in a fixed order, each one
            # constructed right before its own training run.

            # Multivariate Linear Regression:
            model = MVl(X[0].shape, Y[0].shape)
            history = _fit_and_save(
                "MVL", model, train_loop,
                EarlyStopping(min_delta=0.01, tolerance= 3),
                train_dataloader, test_dataloader, n, context,
            )

            # Partial Trace Regression:
            output_shape = Y.shape[1]
            input_shape = X[0].shape
            model = KrausLayer(input_shape, output_shape, rank)
            history = _fit_and_save(
                "PTR", model, train_loop,
                EarlyStopping(min_delta=0.00001, tolerance= 3),
                train_dataloader, test_dataloader, n, context,
            )

            # Reduced Rank Regression
            model = RRMVL(rank * 10, X[0].shape, Y[0].shape)
            history = _fit_and_save(
                "RRR", model, train_loop,
                EarlyStopping(min_delta=0.01, tolerance= 3),
                train_dataloader, test_dataloader, n, context,
            )

            # Trace Regression (trained with the singular-value-penalised loop)
            model = TraceLayer(X[0].shape, Y[0].shape)
            history = _fit_and_save(
                "TR", model, tr_train_loop,
                EarlyStopping(min_delta=0.001, tolerance= 5),
                train_dataloader, test_dataloader, n, context,
            )

Epoch 1
-------------------------------
loss: 1904.976562  [   32/10000]
loss: 878.144531  [ 3232/10000]
loss: 359.032196  [ 6432/10000]
loss: 153.221695  [ 9632/10000]
Test Error: 
  Avg loss: 145.586749 

Epoch 2
-------------------------------
loss: 156.855286  [   32/10000]
loss: 66.777740  [ 3232/10000]
loss: 39.044132  [ 6432/10000]
loss: 27.328733  [ 9632/10000]
Test Error: 
  Avg loss: 23.769471 

Epoch 3
-------------------------------
loss: 23.725960  [   32/10000]
loss: 17.750263  [ 3232/10000]
loss: 15.623999  [ 6432/10000]
loss: 12.655618  [ 9632/10000]
Test Error: 
  Avg loss: 12.460037 

Epoch 4
-------------------------------
loss: 11.501343  [   32/10000]
loss: 11.381852  [ 3232/10000]
loss: 8.320279  [ 6432/10000]
loss: 7.430315  [ 9632/10000]
Test Error: 
  Avg loss: 7.375814 

Epoch 5
-------------------------------
loss: 6.673481  [   32/10000]
loss: 5.739625  [ 3232/10000]
loss: 5.105289  [ 6432/10000]
loss: 4.965164  [ 9632/10000]
Test Error: 
  Avg loss: 4.70378

# Evaluating