### Libraries

In [9]:
#Essentials
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset, TensorDataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import os
from umap import UMAP

#User libraries
from BatchEffectDataLoader import DataPreprocess, DataTransform

### Autoencoder class

In [16]:
class Autoencoder(nn.Module):
    def __init__(self, d_z = 10, input_size = 24, batch_size = 2):
        super().__init__()
        self.d_z = d_z
        self.input_size = input_size
        self.batch_size = batch_size
        self.encoder = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.Linear(128, 64),
            nn.Linear(64, d_z)
        )

        self.decoder = nn.Sequential(
            nn.Linear(d_z, 64),
            nn.Linear(64, 128),
            nn.Linear(128, input_size)
        )

    def forward(self, x):
        x = x.view(-1, self.input_size)
        z = self.encoder(x)
        decoded = self.decoder(z)
        return decoded.view(-1, self.input_size)
    
    def encode(self, x):
        x = x.view(-1, self.input_size)
        return self.encoder(x)
    
    def decode(self, z):
        return self.decoder(z)

### Train model function

In [20]:
def train_model(
        model,
        train_loader,
        criterion,
        optimizer,
        num_epochs,
        device,
        val_loader = None,
        test_loader = None,
        model_name = "model",
        save_model = False
):
    train_losses = []
    test_losses = []
    val_losses = []
    lowest_val_loss = float('inf')
    best_epoch = -1
    best_model_state = None

    for epoch in range(num_epochs):
        #Training
        model.train()
        train_loss = 0.0
        for x in train_loader:
            #Forward pass
            x = x.to(device)
            out = model(x)

            #Backpropagation
            loss = criterion(out, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #Save loss values
            train_loss += loss.item()
        train_loss /= len(train_loader)
        train_losses.append(train_loss)

        #Validation
        model.eval()
        if isinstance(val_loader, DataLoader):
            val_loss = 0.0
            with torch.no_grad():
                for x in val_loader:
                    #Forward pass
                    x = x.to(device)
                    out = model(x)

                    #Save loss values
                    loss = criterion(out, x)
                    val_loss += loss.item()
                
            val_loss /= len(val_loader)
            val_losses.append(val_loss)

            if val_loss < lowest_val_loss:
                lowest_val_loss = val_loss
                best_epoch = epoch
                best_model_state = model.state_dict()
            print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        
        else:
            print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f}")

    #Testing
    if isinstance(test_loader, DataLoader):
        test_loss = 0.0
        with torch.no_grad():
            for x in test_loader:
                #Forward pass
                x = x.to(device)
                out = model(x)

                #Save loss values
                loss = criterion(out, x)
                test_loss += loss.item()
        
        test_loss /= len(test_loader)
        test_losses.append(test_loss)
        print(f"Test Loss: {test_loss:.4f}")

    if save_model == True:
        script_path = os.getcwd()
        dest = os.path.join(script_path, f'model_weights/{model_name}_best.pth')
        if not os.path.exists(os.path.dirname(dest)):
            os.makedirs(os.path.dirname(dest), exist_ok=True)
        torch.save(best_model_state, dest)

    return train_losses, val_losses, test_losses, best_epoch

### Dataloader

In [21]:
#Load and preprocess data
path = "data/dataset_sponge.csv"
data = DataTransform(DataPreprocess(path))

#Convert data to tensor (desired structure: tensor([otus], [tissue], [batch]))
otu_data = data.select_dtypes(include = "number")
otu_tensor = torch.tensor(otu_data.values, dtype = torch.float32)
ohe_tissue = data["tissue"]
ohe_batch = data["batch"]

# dataloader = DataLoader(TensorDataset(otu_tensor), batch_size = 32) #this should be the correct way to define it
dataloader = DataLoader(otu_tensor, batch_size = 32)

### Training loop

In [22]:
#Setting up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs=20
print(f"Using {device}")

#Defining autoencoder and optimizer
d_z = 2
autoencoder = Autoencoder(d_z=d_z).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3, weight_decay=1e-5)

#Training loop
train_losses, val_losses, test_losses, best_epoch = train_model(
    model=autoencoder,
    train_loader=dataloader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    device=device,
    model_name="autoencoder" 
)

Using cpu
Epoch 1/20 | Train Loss: 78.1985
Epoch 2/20 | Train Loss: 76.4306
Epoch 3/20 | Train Loss: 74.1178
Epoch 4/20 | Train Loss: 70.6355
Epoch 5/20 | Train Loss: 65.9253
Epoch 6/20 | Train Loss: 60.1548
Epoch 7/20 | Train Loss: 53.7175
Epoch 8/20 | Train Loss: 47.3269
Epoch 9/20 | Train Loss: 42.0864
Epoch 10/20 | Train Loss: 39.2228
Epoch 11/20 | Train Loss: 38.6762
Epoch 12/20 | Train Loss: 38.4233
Epoch 13/20 | Train Loss: 37.4208
Epoch 14/20 | Train Loss: 36.1617
Epoch 15/20 | Train Loss: 35.2338
Epoch 16/20 | Train Loss: 34.7897
Epoch 17/20 | Train Loss: 34.7088
Epoch 18/20 | Train Loss: 34.8039
Epoch 19/20 | Train Loss: 34.9119
Epoch 20/20 | Train Loss: 34.9190
