In [None]:
# libraries
import pandas as pd
from pathlib import Path
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset

In [None]:
#device
# Compute device: use the CUDA GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
def robust_scale(x, eps=1e-6):
    """Centre ``x`` on its median and divide by its median absolute deviation.

    Args:
        x: numpy array of values.
        eps: added to the MAD so a constant input (MAD == 0) does not divide by zero.

    Returns:
        ``(x - median(x)) / (MAD(x) + eps)`` with the same shape as ``x``.
    """
    centred = x - np.median(x)
    spread = np.median(np.abs(centred))
    return centred / (spread + eps)

def build_label_vector(
    df_signal,
    df_peaks,
    plant_idx,
    channel,
    multiplex,
    sigma=10,
):
    """Binary float32 mask over ``df_signal["molw"]`` marking main-peak windows.

    Rows of ``df_peaks`` are kept when ``channel``, ``plant_id`` and
    ``multiplex`` match the arguments and ``peak_kind == "main"``. For each kept
    row, every position whose ``molw`` lies in ``[mu_pb - sigma, mu_pb + sigma]``
    (both ends inclusive) is set to 1.0; all other positions stay 0.0.

    Returns:
        ``np.ndarray`` of dtype float32 with one entry per row of ``df_signal``.
    """
    positions = df_signal["molw"].values
    is_selected = (
        (df_peaks["channel"] == channel)
        & (df_peaks["plant_id"] == plant_idx)
        & (df_peaks["multiplex"] == multiplex)
        & (df_peaks["peak_kind"] == "main")
    )

    labels = np.zeros(len(positions), dtype=np.float32)
    for centre in df_peaks.loc[is_selected, "mu_pb"]:
        window = (positions >= centre - sigma) & (positions <= centre + sigma)
        labels[window] = 1.0
    return labels


class PeakDataset(Dataset):
    """Per-channel signals with binary peak-window labels, built from CSV files.

    Every CSV in ``csv_files`` is read with ``sep=";"`` and contributes four
    samples, one per ``channel_1`` .. ``channel_4`` column. A sample is the
    robust-scaled channel as a ``(1, L)`` float32 tensor. Its label is the
    ``(L,)`` tensor produced by ``build_label_vector`` for the matching plant,
    channel and multiplex.

    Args:
        csv_files: iterable of signal CSV paths (``str`` or ``Path``). The plant
            index is parsed from the text after the last ``"_pl"`` in the file
            stem. The multiplex is parsed from the text between the last ``"M"``
            and the following ``"_pl"``.
        peaks_csv: comma-separated peak table handed to ``build_label_vector``.
            It is read once for all files.
        expected_length: required number of values per channel. ``None``
            disables the check.
        verbose: if True, print ``plant_idx ch_idx multiplex shape`` per sample.

    Raises:
        ValueError: a channel does not have ``expected_length`` values, or the
            plant/multiplex text in a file stem is not an integer.
    """

    def __init__(
        self,
        csv_files,
        peaks_csv="../Data/synthetic_ce_outputs_v5/peak_positions_detailed.csv",
        expected_length=4961,
        verbose=False,
    ):
        self.samples = []
        self.labels = []

        csv_files = [Path(p) for p in csv_files]
        # The peak table does not depend on the file or channel, so read it once.
        # It used to be re-read for every (file, channel) pair.
        # NOTE(review): this default points at the v5 outputs while the notebook
        # loads signals from synthetic_ce_outputs_v9. Confirm that the versions match.
        df_peaks = pd.read_csv(peaks_csv, sep=",") if csv_files else None

        for csv_path in csv_files:
            df_signal = pd.read_csv(csv_path, sep=";")

            # File-level metadata is the same for all four channels.
            stem = csv_path.stem
            plant_idx = int(stem.split("_pl")[-1])
            multiplex = int(stem.split("M")[-1].split("_pl")[0])

            for ch in ("channel_1", "channel_2", "channel_3", "channel_4"):
                ch_idx = int(ch.split("_")[1])

                x = df_signal[ch].values.astype(np.float32)
                # An explicit error survives `python -O` (unlike `assert`) and
                # names the offending file/channel.
                if expected_length is not None and len(x) != expected_length:
                    raise ValueError(
                        f"{csv_path.name}:{ch} has {len(x)} values, "
                        f"expected {expected_length}"
                    )

                x = torch.from_numpy(robust_scale(x)).unsqueeze(0)  # shape (1, L)
                if verbose:
                    print(plant_idx, ch_idx, multiplex, x.shape)

                y = build_label_vector(
                    df_signal,
                    df_peaks,
                    plant_idx,
                    channel=ch_idx,
                    multiplex=multiplex,
                )

                self.samples.append(x)
                self.labels.append(torch.from_numpy(y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Return the (signal, label) pair. Returning only the signal made the
        # dataset unusable for supervised training through a DataLoader.
        return self.samples[idx], self.labels[idx]


In [None]:
from torch.utils.data import DataLoader

SIGNAL_DIR = Path("../Data/synthetic_ce_outputs/synthetic_ce_outputs_v9/csv")

# sorted(): glob order is filesystem-dependent. A fixed order keeps the seeded
# train/test split further down reproducible across machines.
csv_files = sorted(SIGNAL_DIR.glob("*.csv"))
# number of files in the folder
print(f"Number of CSV files found: {len(csv_files)}")
if not csv_files:
    raise FileNotFoundError(f"No CSV files found in {SIGNAL_DIR.resolve()}")

dataset = PeakDataset(csv_files)
x_data = dataset.samples
y_data = dataset.labels
assert len(x_data) == len(y_data), "every signal needs exactly one label vector"

print(len(x_data))
print(len(y_data))



In [None]:
# Sanity check: sample/label counts first, then the shape of the first entry of each.
for collection in (x_data, y_data):
    print(len(collection))
for collection in (x_data, y_data):
    print(collection[0].shape)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# First dataset sample (not a loader batch), with its labelled peak windows shaded.
signal = x_data[0][0].numpy()        # x_data[0] is (1, L): take its single row
peak_mask = y_data[0].numpy() > 0.5  # build_label_vector writes 0.0 / 1.0

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(signal, linewidth=0.8, label="signal")
ax.fill_between(
    np.arange(signal.size), signal.min(), signal.max(),
    where=peak_mask, color="tab:orange", alpha=0.3, label="labelled peak window",
)
ax.set(
    title="Example Signal (dataset sample 0)",
    xlabel="Molecular Weight Index",
    ylabel="Robust-scaled signal",
)
ax.legend()
plt.show()

In [None]:
print(y_data[0].numpy())

# Later cells may read `y_data` as one stacked (N, L) tensor, so it stays stacked.
# Stacking `list(y_data)` makes the cell idempotent: the old
# `torch.stack(y_data)` failed on a re-run because y_data was already a tensor.
y_data = torch.stack(list(y_data))

fig, ax = plt.subplots(figsize=(10, 6))
# imshow renders an (N, L) matrix quickly. The previous sns.heatmap call was
# commented out, which left this figure empty.
image = ax.imshow(y_data.numpy(), aspect="auto", interpolation="nearest", cmap="viridis")
fig.colorbar(image, ax=ax)
ax.set(
    title="Peak Label Masks (all samples)",
    xlabel="Molecular Weight Index",
    ylabel="Sample Index",
)
plt.show()

In [None]:
#!pip install matplotlib

In [None]:
# train test split
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

BATCHSIZE = 64
TEST_SIZE = 0.2
SPLIT_SEED = 42
NUM_WORKERS = 4
PIN_MEMORY = torch.cuda.is_available()  # pinned host memory only helps CUDA copies

# Stack once. `list(...)` accepts y_data either as the original list of
# per-sample tensors or as the tensor a plotting cell may have stacked it into,
# so this cell no longer silently depends on that cell having run.
X_all = torch.stack(list(x_data))          # (N, 1, L)
y_all = torch.stack(list(y_data)).float()  # (N, L)
assert len(X_all) == len(y_all), "signal/label count mismatch"

# Split indices rather than the data itself. Same seed means the same partition.
# NOTE(review): samples are split per (file, channel), so the four channels of one
# CSV file can land on both sides of the split. Consider a grouped split if those
# channels are correlated.
train_idx, test_idx = train_test_split(
    np.arange(len(X_all)),
    test_size=TEST_SIZE,
    random_state=SPLIT_SEED,
)
train_idx = torch.as_tensor(train_idx)
test_idx = torch.as_tensor(test_idx)

X_train, y_train = X_all[train_idx], y_all[train_idx]
X_test, y_test = X_all[test_idx], y_all[test_idx]

train_dataset = TensorDataset(X_train, y_train)
# drop_last=True: with fewer than BATCHSIZE training samples this loader yields nothing.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCHSIZE,
    shuffle=True,
    drop_last=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    generator=torch.Generator().manual_seed(SPLIT_SEED),  # reproducible shuffling
)

test_dataset = TensorDataset(X_test, y_test)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCHSIZE,
    shuffle=False,
    drop_last=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

In [None]:
# Split sizes in the order train X, train y, test X, test y, then container types.
for split_part in (X_train, y_train, X_test, y_test):
    print(len(split_part))
print(type(X_train))
print(type(X_train[0]))

In [None]:
# Container type of each label split.
for label_split in (y_train, y_test):
    print(type(label_split))

In [None]:

# visualisation of a batch
# next(iter(...)) pulls exactly one batch. It raises StopIteration if the loader is
# empty (e.g. fewer training samples than BATCHSIZE with drop_last=True). The old
# `for ... break` silently left x_batch / y_batch undefined in that case.
x_batch, y_batch = next(iter(train_loader))
print(x_batch.shape)  # (batch, 1, L), e.g. torch.Size([64, 1, 4961])
print(y_batch.shape)  # (batch, L)

# signal plot of one sample (index 2 as before, clamped for tiny batches)
import matplotlib.pyplot as plt

sample_idx = min(2, len(x_batch) - 1)
print(len(x_batch[sample_idx][0]))

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(x_batch[sample_idx][0].numpy())
ax.set(
    title=f"Example Signal from Batch (sample {sample_idx})",
    xlabel="Molecular Weight Index",
    ylabel="Robust-scaled signal",
)
plt.show()

In [None]:
#!pip install seaborn

In [None]:

# visualisation of a batch of labels as a heatmap
import seaborn as sns

# reshape instead of squeeze(): squeeze() also drops the batch axis when the batch
# holds a single sample, which would hand seaborn a 1-D array.
label_matrix = y_batch.reshape(len(y_batch), -1).numpy()

fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(label_matrix, cmap="viridis", cbar=True, ax=ax)
ax.set(
    title="Heatmap of Label Batch",
    xlabel="Molecular Weight Index",
    ylabel="Sample Index in Batch",
)
plt.show()

In [None]:
class down_block(nn.Module):
    def __init__(self, in_channel,out_channel,kernel, padding,stride):
        super().__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=kernel, padding=padding, stride=stride)
        self.norm = nn.BatchNorm1d(out_channel)
        self.block = nn.Sequential(
            self.conv,
            self.norm,
            nn.LeakyReLU(0.25)
        )
    
    def forward(self, x):
        return self.block(x)

class up_block(nn.Module):
    def __init__(self,in_channel,out_channel,kernel, padding,stride, apply_Dropout=True):
        super().__init__()
        self.conv = nn.ConvTranspose1d(in_channel, out_channel, kernel_size=kernel, padding=padding, stride=stride)
        self.norm = nn.BatchNorm1d(out_channel*2) #after concatenation
        self.refine = nn.Conv1d(out_channel * 2, out_channel, kernel_size=3, padding=1)
        self.post = nn.Sequential(
            self.refine, #non linearity after concatenation
            self.norm,
            nn.Dropout(0.25) if apply_Dropout else nn.Identity(),
            nn.LeakyReLU(0.25)
        )

    def forward(self, x1,x2):
        x1 = self.conv(x1)
        if x1.shape[-1] != x2.shape[-1]:
            diff = x1.shape[-1] - x2.shape[-1]
            x1 = x1[..., diff//2 : diff//2 + x2.shape[-1]] # centered cropping Ronneberger et al., 2015
        #concatenation
        x = torch.cat([x2, x1], dim=1)
        x = self.post(x)
        return x

class U_net(nn.Module):
    def __init__(self, in_channel,hidden_unit,out_channel,list_kernel, apply_Dropout=True):
        super().__init__()

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        for i in range(len(list_kernel)):
            pad = list_kernel[i]//2
            self.downs.append(down_block(in_channel, hidden_unit, kernel=list_kernel[i], padding=pad, stride=2))
            in_channel = hidden_unit
        list_kernel = list_kernel[::-1]
        for i in range(len(list_kernel)-1):
            pad = list_kernel[i]//2
            self.ups.append(up_block(hidden_unit, hidden_unit, kernel=list_kernel[i], padding=pad, stride=2, apply_Dropout=apply_Dropout))
        
        self.final = nn.ConvTranspose1d(hidden_unit, out_channel, kernel_size=list_kernel[-1], padding=list_kernel[-1]//2, stride=1)
        
    def forward(self, x):
        skip = [] #x for skip connections
        for down in self.downs:
            skip.append(x)   # Before downsampling
            x = down(x)
        print("encoder", x.shape)
        skip = skip[::-1] # reverse for skip connections
        for i in range(len(self.ups)):
            x = self.ups[i](x, skip[i])
        print("decoder", x.shape)
        x = self.final(x)
        print("final", x.shape)
        return x  # (batch,1, 4961)

In [None]:
#!pip install torchmetrics

In [None]:
#Metrics 
# https://arxiv.org/pdf/2101.01666


In [None]:
import torch.optim as optim
import torchmetrics
from torchmetrics.classification import BinaryPrecision, BinaryRecall
from tqdm.auto import tqdm

EPOCHS = 5
LEARNING_RATE = 1e-3
SEED = 42

torch.manual_seed(SEED)

# Instantiate the model. U_net's constructor takes `in_channel` / `out_channel`;
# the previous `in_feature` / `out_feature` keywords raised a TypeError.
net = U_net(
    in_channel=1,
    hidden_unit=10,
    out_channel=1,
    list_kernel=[9, 9, 6, 6, 3, 3],
    apply_Dropout=True,
).to(device)

# NOTE(review): labelled windows are presumably a small fraction of each trace;
# if the model collapses to all-zero predictions, consider BCEWithLogitsLoss(pos_weight=...).
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
# Metrics must live on the same device as the tensors they are updated with.
precision_metric = BinaryPrecision().to(device)
recall_metric = BinaryRecall().to(device)

for epoch in tqdm(range(EPOCHS), desc="epochs"):
    net.train()
    precision_metric.reset()
    recall_metric.reset()
    running_loss = 0.0
    n_batches = 0

    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)

        optimizer.zero_grad()
        logits = net(x_batch)
        # Labels arrive as (batch, L); the logits are (batch, 1, L). Align them.
        y_batch = y_batch.to(device).float().reshape(logits.shape)
        loss = loss_fn(logits, y_batch)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
        # Accumulate epoch-level metrics (default threshold 0.5) from detached probabilities.
        probs = torch.sigmoid(logits.detach())
        precision_metric.update(probs, y_batch.int())
        recall_metric.update(probs, y_batch.int())

    if n_batches == 0:
        raise RuntimeError(
            "train_loader yielded no batches (fewer samples than the batch size with drop_last=True?)"
        )
    print(
        f"epoch {epoch + 1}/{EPOCHS}  "
        f"loss={running_loss / n_batches:.4f}  "
        f"precision={precision_metric.compute().item():.4f}  "
        f"recall={recall_metric.compute().item():.4f}"
    )

    


In [None]:
# Evaluation loop — TODO: not implemented yet (evaluate `net` on `test_loader`)

In [None]:
# Save the model — TODO: not implemented yet