## Imports

In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from resnet import ResNet
from training import run_epoch
import matplotlib.pyplot as plt
from data import RamanSpectraDataset  
# from data2 import RamanSpectraDataset  
import numpy as np
import optuna
from optuna.trial import Trial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from augment import apply_augmentation

## ResNet

In [None]:
class ResidualBlock(nn.Module):
    """Two-convolution 1-D residual block with an optional projection shortcut.

    Main path: conv(k=5) -> BN -> ReLU -> conv(k=5) -> BN.  The main path is
    summed with the shortcut path and passed through a final ReLU.  The
    shortcut is the identity (an empty ``nn.Sequential``) unless the stride or
    the channel count changes; in that case a strided 1x1 convolution
    followed by BN projects the input to the main path's output shape.

    Args:
        in_channels: number of channels of the block input.
        out_channels: number of channels produced by the block.
        stride: stride of the first convolution and of the projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main path.  kernel_size=5 with padding=2 preserves the length at
        # stride 1.
        self.conv1 = nn.Conv1d(
            in_channels, out_channels,
            kernel_size=5, stride=stride, padding=2, dilation=1, bias=False,
        )
        self.bn1 = nn.BatchNorm1d(num_features=out_channels)
        self.conv2 = nn.Conv1d(
            out_channels, out_channels,
            kernel_size=5, stride=1, padding=2, dilation=1, bias=False,
        )
        self.bn2 = nn.BatchNorm1d(num_features=out_channels)

        # Shortcut path: project only when the output shape differs from x.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the block: ReLU(main(x) + shortcut(x))."""
        identity = self.shortcut(x)
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + identity)


class ResNet(nn.Module):
    """1-D ResNet encoder followed by a linear head.

    Input of shape ``(batch, 1, input_dim)`` passes through a stem
    (conv k=5 -> BN -> ReLU) and then one residual stage per entry of
    ``hidden_sizes``.  The result is flattened and fed to a linear layer with
    ``n_classes`` outputs.  Stage ``i`` holds ``num_blocks[i]`` ResidualBlocks
    with ``hidden_sizes[i]`` channels; every stage after the first begins
    with a stride-2 block.

    Args:
        hidden_sizes: channel count of each residual stage.
        num_blocks: ResidualBlocks per stage (same length as hidden_sizes).
        input_dim: length of the input signal; used once at construction to
            size the linear head.
        in_channels: channels produced by the stem convolution.
        n_classes: number of outputs of the linear head.

    Raises:
        ValueError: if ``hidden_sizes`` and ``num_blocks`` differ in length.
    """

    def __init__(self, hidden_sizes, num_blocks, input_dim=1000,
                 in_channels=64, n_classes=1):
        super().__init__()
        # Raise instead of assert so the check survives `python -O`.
        if len(num_blocks) != len(hidden_sizes):
            raise ValueError(
                f"num_blocks (len {len(num_blocks)}) and hidden_sizes "
                f"(len {len(hidden_sizes)}) must have the same length")
        self.input_dim = input_dim
        # NOTE: _make_layer advances this to the channel count of the last
        # stage it builds (it stays unchanged if there are no stages).
        self.in_channels = in_channels
        self.n_classes = n_classes

        self.conv1 = nn.Conv1d(1, self.in_channels, kernel_size=5, stride=1,
                               padding=2, bias=False)
        self.bn1 = nn.BatchNorm1d(self.in_channels)

        # Only the first stage keeps full resolution; later stages halve it.
        strides = [1] + [2] * (len(hidden_sizes) - 1)
        layers = []
        for hidden_size, blocks, stage_stride in zip(hidden_sizes, num_blocks,
                                                     strides):
            layers.append(self._make_layer(hidden_size, blocks,
                                           stride=stage_stride))
        self.encoder = nn.Sequential(*layers)

        self.z_dim = self._get_encoding_size()
        self.linear = nn.Linear(self.z_dim, self.n_classes)

    def encode(self, x):
        """Return the flattened encoder features, shape ``(batch, z_dim)``."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.encoder(x)
        return x.view(x.size(0), -1)

    def forward(self, x):
        """Return the linear head output, shape ``(batch, n_classes)``."""
        return self.linear(self.encode(x))

    def _make_layer(self, out_channels, num_blocks, stride=1):
        """Build one stage of ResidualBlocks.

        Only the first block uses ``stride``; the rest use stride 1.  Because
        the stride list always contains the first entry, ``num_blocks < 1``
        still yields one block.  As a side effect ``self.in_channels`` is
        advanced to ``out_channels`` so the next stage chains onto this one.
        """
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(ResidualBlock(self.in_channels, out_channels,
                                        stride=block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*blocks)

    def _get_encoding_size(self):
        """Return the flattened encoder size for a length-``input_dim`` input.

        The dummy pass runs in eval mode under ``no_grad``.  In train mode it
        would update every BatchNorm's running statistics (and
        ``num_batches_tracked``) with random data, and it would fail when a
        feature map shrinks to a single value per channel.  The previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                z = self.encode(torch.rand(1, 1, self.input_dim))
        finally:
            self.train(was_training)
        return z.size(1)


def add_activation(activation='relu'):
    """Return a new activation module selected by name.

    Supported names (hyper-parameters are fixed as shown):

    - 'relu'
    - 'elu' (alpha=1.0)
    - 'selu'
    - 'leaky relu' (negative_slope=0.1)
    - 'sigmoid'
    - 'tanh'
    - 'softplus' (beta=1, threshold=20)

    Raises:
        ValueError: if ``activation`` is not one of the names above.  An
            unknown name used to produce ``None``, which only failed later
            inside the model.
    """
    factories = {
        'relu': nn.ReLU,
        'elu': lambda: nn.ELU(alpha=1.0),
        'selu': nn.SELU,
        'leaky relu': lambda: nn.LeakyReLU(negative_slope=0.1),
        'sigmoid': nn.Sigmoid,
        'tanh': nn.Tanh,
        'softplus': lambda: nn.Softplus(beta=1, threshold=20),
    }
    try:
        factory = factories[activation]
    except KeyError:
        raise ValueError(
            f"Unknown activation {activation!r}; expected one of "
            f"{sorted(factories)}") from None
    # Call the factory so each caller gets its own module instance.
    return factory()

## Data Loader

In [None]:
# === Dataset Definition ===
class RamanSpectraDataset(Dataset):
    """Raman spectra stored as ``<root_dir>/<k>/<name>.txt`` text files.

    Each sub-directory of ``root_dir`` whose name ``k`` parses as the exponent
    in ``1e-<k>`` is one concentration class, labelled ``log10(1e-k)``
    (e.g. folder ``"5"`` -> -5.0).  Other directories, loose files in
    ``root_dir``, and non-``.txt`` files are ignored.  A spectrum file is read
    with ``np.loadtxt`` and its second column is used as the intensities.

    Args:
        root_dir: directory containing the per-concentration folders.
        augment: if True, ``apply_augmentation`` is applied to every sample
            each time it is accessed.
        offline_aug: if True, every spectrum file is additionally listed
            ``num_aug`` times as an entry flagged for augmentation.
        num_aug: number of extra flagged entries per file when
            ``offline_aug`` is True.
    """

    def __init__(self, root_dir, augment=False, offline_aug=False, num_aug=2):
        # Entries are (fpath, label) or (fpath, label, True) for flagged copies.
        self.samples = []
        self.augment = augment
        self.offline_aug = offline_aug
        self.num_aug = num_aug

        # Sorted so the sample order (and therefore any seeded random split)
        # does not depend on the filesystem's enumeration order.
        for folder in sorted(os.listdir(root_dir)):
            folder_path = os.path.join(root_dir, folder)
            if not os.path.isdir(folder_path):
                continue
            try:
                label = np.log10(float(f"1e-{folder}"))  # e.g. "5" -> -5.0
            except ValueError:
                continue  # folder name is not a concentration exponent

            for fname in sorted(os.listdir(folder_path)):
                if not fname.endswith('.txt'):
                    continue
                fpath = os.path.join(folder_path, fname)
                self.samples.append((fpath, label))
                # Flagged copies belong to this spectrum file only; the
                # third item marks "needs augmentation".
                if self.offline_aug:
                    self.samples.extend(
                        (fpath, label, True) for _ in range(self.num_aug))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for entry ``idx``.

        ``x`` is a float32 tensor of shape ``(1, L)`` with the per-spectrum
        z-score-normalised intensities (assumes ``apply_augmentation`` returns
        a 1-D array -- TODO confirm).  ``y`` is the scalar float32 label.
        """
        entry = self.samples[idx]
        fpath, label = entry[0], entry[1]
        to_augment = len(entry) == 3 and entry[2]

        # ndmin=2 keeps a single-row file indexable as data[:, 1].
        data = np.loadtxt(fpath, ndmin=2)
        intensities = data[:, 1]

        if self.augment or to_augment:
            intensities = apply_augmentation(intensities)

        # Normalize; the epsilon guards against a perfectly flat spectrum.
        intensities = (intensities - intensities.mean()) / (intensities.std() + 1e-8)
        x = torch.tensor(intensities[np.newaxis, :], dtype=torch.float32)
        y = torch.tensor(label, dtype=torch.float32)
        return x, y

# === Load & Split Dataset ===
def get_loaders(root_path, batch_size=32, seed=None):
    """Build train/val/test DataLoaders over one RamanSpectraDataset.

    The dataset is split 80/10/10 with ``random_split``; the test split
    receives whatever the integer truncation leaves over.  Only the training
    loader shuffles.

    Args:
        root_path: root directory passed to ``RamanSpectraDataset``.
        batch_size: batch size used by all three loaders.
        seed: optional integer seed for the split.  When given, the same
            samples land in the same split on every run, so a held-out test
            set stays held out when a saved model is re-evaluated.  When
            None, torch's global RNG is used.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    dataset = RamanSpectraDataset(root_path)

    n_total = len(dataset)
    train_size = int(0.8 * n_total)
    val_size = int(0.1 * n_total)
    test_size = n_total - train_size - val_size

    # Only pass a generator when seeded, so the unseeded path is unchanged.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_set, val_set, test_set = random_split(
        dataset, [train_size, val_size, test_size], **split_kwargs)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    test_loader = DataLoader(test_set, batch_size=batch_size)

    return train_loader, val_loader, test_loader


## Training

In [None]:
def _regression_metrics(y_true, y_pred):
    """Return ``(mae, rmse, r2)`` for 1-D float arrays of equal length.

    For constant targets (zero total variance), R^2 is 1.0 on a perfect fit
    and 0.0 otherwise, mirroring scikit-learn's ``r2_score`` default.
    """
    residuals = y_true - y_pred
    mae = float(np.mean(np.abs(residuals)))
    rmse = float(np.sqrt(np.mean(residuals ** 2)))
    ss_res = float(np.sum(residuals ** 2))
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
    if ss_tot == 0.0:
        r2 = 1.0 if ss_res == 0.0 else 0.0
    else:
        r2 = 1.0 - ss_res / ss_tot
    return mae, rmse, r2


def run_epoch(epoch, model, dataloader, cuda, training=False, optimizer=None,
              round_preds=True, clip_range=(-9, -5)):
    """Run one pass over ``dataloader`` and return loss and regression metrics.

    With ``training`` set, the model is put in train mode and ``optimizer`` is
    stepped on the MSE loss of every batch.  Otherwise the model is put in
    eval mode and the pass runs without building autograd graphs.  The model
    output is squeezed on its last dimension only, so a single-output model
    gives one prediction per sample even for a batch of size 1.

    Args:
        epoch: epoch index; used only in the progress-bar label.
        model: module mapping a batch of inputs to predictions.
        dataloader: iterable of ``(inputs, targets)`` batches.
        cuda: if True, batches are moved to the default CUDA device.
        training: whether to update the model.
        optimizer: optimizer to step; required when ``training`` is True.
        round_preds: if True, predictions are rounded to the nearest integer
            and clipped to ``clip_range`` before computing the metrics.  The
            loss always uses the raw predictions.
        clip_range: ``(low, high)`` bounds applied to rounded predictions.

    Returns:
        ``(avg_loss, mae, rmse, r2)``: the sample-weighted mean MSE loss and
        the MAE, RMSE and R^2 of the (optionally rounded) predictions.

    Raises:
        ValueError: if training without an optimizer, or if the dataloader
            yields no samples.
    """
    # The progress bar is optional: fall back to plain iteration.
    try:
        from tqdm import tqdm
    except ImportError:
        def tqdm(iterable, **_kwargs):
            return iterable

    if training and optimizer is None:
        raise ValueError("an optimizer is required when training=True")

    if training:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    total_samples = 0
    pred_chunks = []
    target_chunks = []
    desc = f"Epoch {epoch} {'Train' if training else 'Val'}"

    with torch.set_grad_enabled(bool(training)):
        for inputs, targets in tqdm(dataloader, desc=desc):
            if cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            targets = targets.float()

            # squeeze(-1), not squeeze(): a size-1 batch must stay 1-D.
            outputs = model(inputs).squeeze(-1)
            loss = F.mse_loss(outputs, targets)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = targets.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            pred_chunks.append(outputs.detach().cpu().numpy().reshape(-1))
            target_chunks.append(targets.detach().cpu().numpy().reshape(-1))

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples")
    avg_loss = total_loss / total_samples

    preds = np.concatenate(pred_chunks).astype(np.float64)
    y_true = np.concatenate(target_chunks).astype(np.float64)
    if round_preds:
        # np.rint rounds half to even, as Python's round() does.
        low, high = clip_range
        preds = np.clip(np.rint(preds), low, high)

    mae, rmse, r2 = _regression_metrics(y_true, preds)
    return avg_loss, mae, rmse, r2


def get_predictions(model, dataloader, cuda, get_probs=False):
    """Collect model outputs over ``dataloader`` as a NumPy array.

    The model is switched to eval mode and run under ``torch.no_grad()``;
    the targets yielded by the loader are ignored.

    Args:
        model: module mapping a batch of inputs to outputs.
        dataloader: iterable of ``(inputs, targets)`` batches.
        cuda: if True, inputs are moved to the default CUDA device.
        get_probs: if True, apply softmax over dim 1 and return the stacked
            per-class probabilities.  Otherwise return the raw outputs with a
            trailing size-1 dimension removed, giving shape ``(N,)`` for a
            single-output model.

    Returns:
        np.ndarray of predictions in loader order.
    """
    model.eval()
    preds = []
    with torch.no_grad():
        for inputs, _targets in dataloader:
            if cuda:
                inputs = inputs.cuda()
            outputs = model(inputs)
            if get_probs:
                probs = F.softmax(outputs, dim=1)
                preds.append(probs.cpu().numpy())
            else:
                # squeeze(-1), not squeeze(): a batch of one sample must still
                # contribute one entry rather than an un-iterable 0-d array.
                preds.extend(outputs.squeeze(-1).cpu().numpy())
    if get_probs:
        return np.vstack(preds)
    return np.array(preds)