In [None]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import KFold, train_test_split
from torch.utils.data import DataLoader, TensorDataset

from model import GeneExpressionPredictionModel

In [None]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# Read the three input tables from CSV and keep only their raw value arrays.
fp_data = pd.read_csv('data/fingerprints.csv').to_numpy()
cp_data = pd.read_csv('data/descriptors.csv').to_numpy()
gene_data = pd.read_csv('data/genes.csv').to_numpy()

# Quick sanity check of the loaded array shapes.
print(fp_data.shape, cp_data.shape, gene_data.shape)

In [None]:
# Hold out 20% of the rows as a test set; the three arrays are split with the
# same row permutation (fixed seed) so they stay aligned.
(fp_train, fp_test,
 cp_train, cp_test,
 gene_train, gene_test) = train_test_split(fp_data, cp_data, gene_data, random_state=42, test_size=0.2)

In [None]:
# Data loaders
def get_data_loader(fp, cp, gene, batch_size, shuffle=True):
    """Wrap three arrays in a ``DataLoader`` of float32 tensors living on ``device``.

    Args:
        fp: Array-like converted to the first tensor of each batch.
        cp: Array-like converted to the second tensor of each batch.
        gene: Array-like converted to the third tensor of each batch.
        batch_size: Number of rows per batch.
        shuffle: Whether the rows are re-drawn in random order on every pass.
            Defaults to ``True`` (the previous hard-coded behaviour); pass
            ``False`` for evaluation loaders so that batch composition stays
            the same from one pass to the next.

    Returns:
        A ``DataLoader`` yielding ``(fp, cp, gene)`` tensor batches.
    """
    def to_device_tensor(array):
        # Create the float32 tensor directly on the target device instead of
        # building it on the CPU first and moving it afterwards.
        return torch.tensor(array, dtype=torch.float32, device=device)

    dataset = TensorDataset(
        to_device_tensor(fp),
        to_device_tensor(cp),
        to_device_tensor(gene),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

In [None]:
# Calculate mean PCC
def get_mean_pcc(model, data_loader):
    """Return the average per-batch Pearson correlation between predictions and targets.

    The model is switched to eval mode and run under ``torch.inference_mode``.
    For every ``(fp, cp, gene)`` batch, the prediction ``model(fp, cp)`` and the
    target ``gene`` are each flattened to one vector and correlated; the
    per-batch coefficients are then averaged with equal weight per batch.

    Args:
        model: Module called as ``model(fp, cp)``.
        data_loader: Iterable of ``(fp, cp, gene)`` tensor batches.

    Returns:
        The mean of the per-batch correlation coefficients.
    """
    model.eval()
    batch_pccs = []
    with torch.inference_mode():
        for fp, cp, gene in data_loader:
            predicted = model(fp, cp).cpu().detach().numpy().ravel()
            target = gene.cpu().detach().numpy().ravel()
            # Off-diagonal entry of the 2x2 correlation matrix is the PCC.
            correlation = np.corrcoef(predicted, target)
            batch_pccs.append(correlation[0, 1])

    return np.mean(batch_pccs)

In [None]:
def initialize_weights(m):
    """Re-initialise a linear layer in place; leave every other module untouched.

    Intended for use with ``Module.apply``: for an ``nn.Linear`` the weight is
    drawn with Xavier-uniform initialisation and the bias, when the layer has
    one, is set to zero.

    Args:
        m: The module to (possibly) re-initialise.
    """
    if not isinstance(m, nn.Linear):
        return

    nn.init.xavier_uniform_(m.weight)

    # Layers created with bias=False have no bias tensor to reset.
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)

In [None]:
# Initialize
# Build one model on `device` and run initialize_weights over all of its
# submodules. Both Module.to and Module.apply return the module itself, so the
# calls can be chained.
model = GeneExpressionPredictionModel().to(device).apply(initialize_weights)

# Metric histories filled by train(): each call appends one per-epoch list.
train_pccs = []
train_losses = []
cv_pccs = []
test_pccs = []

In [None]:
# Training func
def train(model, train_loader, cv_loader, test_loader, optimizer, loss_func, epochs=10, patience=20):
    """Fit ``model`` with early stopping and return it with its best weights loaded.

    After every epoch the sample-weighted mean training loss and the mean
    per-batch PCC on the train, validation and test loaders are recorded.
    Early stopping watches the *training* loss: an epoch only counts as an
    improvement when it lowers the best loss seen so far by more than 1e-3.

    Side effects: the per-epoch histories of this run are appended to the
    module-level lists ``train_losses``, ``train_pccs``, ``cv_pccs`` and
    ``test_pccs``. The module-level ``device`` selects the autocast backend.

    Args:
        model: Network called as ``model(fp, cp)``.
        train_loader: Yields the ``(fp, cp, gene)`` batches used for optimisation.
        cv_loader: Yields validation batches (evaluated only).
        test_loader: Yields test batches (evaluated only).
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_func: Callable ``loss_func(output, gene)`` returning a scalar loss.
        epochs: Maximum number of epochs to run.
        patience: Number of consecutive non-improving epochs tolerated before
            training stops early.

    Returns:
        ``model``, with the weights of the best epoch (lowest training loss)
        loaded.
    """
    # state_dict() hands back references to the live parameter tensors, so the
    # snapshot must be deep-copied; otherwise later optimizer steps silently
    # overwrite the "best" weights and the final load_state_dict is a no-op.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = np.inf
    epochs_no_improve = 0

    epoch_losses, epoch_train_pccs, epoch_val_pccs, epoch_test_pccs = [], [], [], []

    for epoch in range(epochs):
        model.train()

        total_loss = 0.0

        for fp, cp, gene in train_loader:
            optimizer.zero_grad()

            with torch.autocast(device_type=device.type):
                output = model(fp, cp)
                loss = loss_func(output, gene)

            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch does not skew the mean.
            total_loss += loss.item() * fp.size(0)

        epoch_loss = total_loss / len(train_loader.dataset)
        epoch_losses.append(epoch_loss)

        model.eval()
        with torch.inference_mode():
            train_pcc = get_mean_pcc(model, train_loader)
            epoch_train_pccs.append(train_pcc)

            cv_pcc = get_mean_pcc(model, cv_loader)
            epoch_val_pccs.append(cv_pcc)

            test_pcc = get_mean_pcc(model, test_loader)
            epoch_test_pccs.append(test_pcc)

        # Report every tenth epoch (epochs 1, 11, 21, ...).
        if epoch % 10 == 0:
            print(
                f'Epoch {epoch + 1}/{epochs} - Loss: {epoch_loss:.4f} - Train PCC: {train_pcc:.4f} - Val PCC: {cv_pcc:.4f} - Test PCC: {test_pcc:.4f}')

        if best_loss - epoch_loss > 1e-3:
            best_loss = epoch_loss
            # Independent copy of the weights as they are at this epoch.
            best_model_wts = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f'Early stopping at epoch {epoch + 1}')
            break

    # Publish this run's histories to the notebook-level accumulators.
    train_losses.append(epoch_losses)
    train_pccs.append(epoch_train_pccs)
    cv_pccs.append(epoch_val_pccs)
    test_pccs.append(epoch_test_pccs)

    # Roll the model back to the best snapshot before handing it to the caller.
    model.load_state_dict(best_model_wts)
    # torch.save(model.state_dict(), 'model.pt')

    return model

In [None]:
# Five-fold cross-validation splitter; rows are shuffled with a fixed seed.
kf = KFold(5, random_state=42, shuffle=True)

In [None]:
# Train one freshly initialised model per cross-validation fold.
batch_size = 16

# Start from empty histories so that re-running this cell does not stack a
# second set of folds on top of the previous run in the plots.
for history in (train_pccs, train_losses, cv_pccs, test_pccs):
    history.clear()

# The held-out test set is the same for every fold, so build its loader once.
test_loader = get_data_loader(fp_test, cp_test, gene_test, batch_size=batch_size)

for train_index, cv_index in kf.split(fp_train):
    train_loader = get_data_loader(fp_train[train_index], cp_train[train_index], gene_train[train_index],
                                   batch_size=batch_size)
    cv_loader = get_data_loader(fp_train[cv_index], cp_train[cv_index], gene_train[cv_index], batch_size=batch_size)

    model = GeneExpressionPredictionModel().to(device)
    # Apply the weight initialisation to the model that is actually trained;
    # a newly constructed model does not receive it automatically.
    model.apply(initialize_weights)

    model = train(model, train_loader, cv_loader, test_loader, optimizer=torch.optim.Adam(model.parameters(), lr=1e-5),
                  loss_func=nn.MSELoss(), epochs=200, patience=10)

In [None]:
# Four side-by-side panels, one per recorded metric, each with one curve per fold.
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(16, 8))

# (axis, per-fold histories, y-axis label, panel title, logarithmic y-axis?)
panels = [
    (ax1, train_losses, 'Loss', 'Train loss over Epochs for each Fold', True),
    (ax2, train_pccs, 'PCC', 'Train PCC over Epochs for each Fold', False),
    (ax3, cv_pccs, 'PCC', 'Validation PCC over Epochs for each Fold', False),
    (ax4, test_pccs, 'PCC', 'Test PCC over Epochs for each Fold', False),
]

for ax, fold_histories, y_label, title, log_scale in panels:
    for fold_idx, values in enumerate(fold_histories):
        # Epochs are numbered from 1 on the x-axis.
        ax.plot(range(1, len(values) + 1), values, label=f'Fold {fold_idx + 1}')
    ax.set_xlabel('Epochs')
    if log_scale:
        ax.set_yscale('log')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()

# Display the plot
plt.tight_layout()
plt.show()