In [None]:
import glob
import os
import re
import time
from collections import defaultdict
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchmetrics
import yaml
from torch.utils.data import random_split, Dataset, DataLoader, Subset
from torchview import draw_graph


# Pick the compute backend: CUDA if present, otherwise Apple MPS, otherwise CPU.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

print(f'using device: {device}')


# Load the run settings from config.yaml in the current working directory.
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)
    

# Settings taken from the config file.
OUTPUT_DIR = config['output_dir']  # glob prefix handed to build_pairs() to find the *.npy files
TRAIN_SIZE = config['train_size']  # used as a percentage of the pairs in split_pairs_torch()
VALID_SIZE = config['valid_size']  # used as a percentage of the pairs in split_pairs_torch()
BATCH_SIZE = config['batch_size']  # batch size of the non-tiny DataLoaders


EPS = 1e-8  # small constant added to denominators and log arguments in the transform helpers
SEED = 42  # default seed for the data split and for model construction


def sum_normalization_func(x):
    """Divide ``x`` by the sum of all its elements (``EPS`` is added to the sum)."""
    total = x.sum() + EPS
    return x / total

def log_transform_func(x):
    """Return ``torch.log(x + EPS)`` element-wise; ``EPS`` keeps zeros finite."""
    shifted = x + EPS
    return torch.log(shifted)

def standardization_func(x, train_mean, train_std):
    """Subtract ``train_mean`` from ``x`` and divide by ``train_std + EPS``."""
    centred = x - train_mean
    scale = train_std + EPS
    return centred / scale

def slice_func(x, ymin, ymax, xmin, xmax):
    """Crop ``x`` to rows ``ymin:ymax`` (first axis) and columns ``xmin:xmax`` (second axis)."""
    rows = slice(ymin, ymax)
    cols = slice(xmin, xmax)
    return x[rows, cols]

def build_pairs(matrix_dir: str):
    """
    Group the ``.npy`` files under ``matrix_dir`` into positive/negative pairs.

    Files are located with the glob pattern ``f'{matrix_dir}*.npy'``, i.e.
    ``matrix_dir`` is used as a plain prefix and must end with a path separator
    when it names a directory.  The sample id is the first ``EE<digits>`` match
    in the *file name*; files without such an id are skipped.  A file whose name
    contains ``negative`` (case-insensitive) fills the ``'negative'`` slot, every
    other file fills the ``'positive'`` slot.

    Returns list of dicts, ordered by sample id:
    [{'sid': sid, 'positive': path, 'negative': path}, ...]
    A slot holds ``''`` when no file was found for it.
    """
    mapping = defaultdict(lambda: {'negative': '', 'positive': ''})

    # glob returns paths in file-system order; sorting keeps the pair order
    # (and therefore the seeded train/valid/test split) stable across machines.
    for p in sorted(glob.glob(f'{matrix_dir}*.npy')):
        # Inspect only the file name so that directory names containing
        # "EE123" or "negative" cannot influence the grouping.
        name = os.path.basename(p)

        match = re.search(r'EE\d+', name)
        if match is None:
            # No sample id in this file name: nothing to pair it with.
            continue
        sid = match[0]

        slot = 'negative' if 'negative' in name.lower() else 'positive'
        mapping[sid][slot] = p

    return [
        {'sid': sid, 'positive': mapping[sid]['positive'], 'negative': mapping[sid]['negative']}
        for sid in sorted(mapping)
    ]

def split_pairs_torch(pairs, seed=SEED):
    """
    Randomly split ``pairs`` into train, validation and test subsets.

    The train and validation sizes are ``TRAIN_SIZE`` and ``VALID_SIZE`` percent
    of ``len(pairs)`` (floored); the test subset receives the remainder.  The
    shuffle uses its own ``torch.Generator`` seeded with ``seed``.
    """
    total = len(pairs)
    n_train = total * TRAIN_SIZE // 100
    n_valid = total * VALID_SIZE // 100
    lengths = [n_train, n_valid, total - n_train - n_valid]

    rng = torch.Generator().manual_seed(seed)
    train_part, valid_part, test_part = random_split(pairs, lengths, generator=rng)
    return train_part, valid_part, test_part

def load_train_pairs(pairs, do_sum_normalization, do_log_transform, slice_params=None):
    """
    Load every matrix referenced by ``pairs`` and stack them into one tensor.

    For each pair the positive file is loaded first, then the negative one.
    Each matrix is cast to float32, optionally cropped with ``slice_params``,
    then optionally sum-normalized and log-transformed (in that order).
    The result is stacked along a new leading dimension.
    """
    def _load(path):
        matrix = torch.from_numpy(np.load(path).astype(np.float32))
        if slice_params is not None:
            matrix = slice_func(matrix, **slice_params)
        if do_sum_normalization:
            matrix = sum_normalization_func(matrix)
        if do_log_transform:
            matrix = log_transform_func(matrix)
        return matrix

    paths = [path for p in pairs for path in (p['positive'], p['negative'])]
    return torch.stack([_load(path) for path in paths], dim=0)


class MatrixDataset(Dataset):
    """
    Dataset yielding ``(matrix, label)`` samples loaded lazily from ``.npy`` files.

    Every pair contributes two samples: its positive file with label 1 and its
    negative file with label 0.  ``items`` holds ``(path, label, sid)`` tuples.
    """

    def __init__(
        self,
        pairs,
        train_mean,
        train_std,
        do_standardization,
        do_sum_normalization,
        do_log_transform,
        slice_params=None,
    ):
        # Positive sample first, negative second, for each pair in order.
        self.items = [
            (p[key], label, p['sid'])
            for p in pairs
            for key, label in (('positive', 1), ('negative', 0))
        ]

        self.train_mean = train_mean
        self.train_std = train_std

        self.do_standardization = do_standardization
        self.do_sum_normalization = do_sum_normalization
        self.do_log_transform = do_log_transform
        self.slice_params = slice_params

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        path, label, _sid = self.items[idx]
        matrix = torch.from_numpy(np.load(path)).float()

        # Preprocessing order: crop -> sum-normalize -> log -> add channel dim -> standardize.
        if self.slice_params is not None:
            matrix = slice_func(matrix, **self.slice_params)

        if self.do_sum_normalization:
            matrix = sum_normalization_func(matrix)
        if self.do_log_transform:
            matrix = log_transform_func(matrix)

        matrix = matrix.unsqueeze(0)  # prepend a channel dimension

        # Standardization is silently skipped when no training statistics were provided.
        has_stats = self.train_mean is not None and self.train_std is not None
        if self.do_standardization and has_stats:
            matrix = standardization_func(matrix, self.train_mean, self.train_std)

        target = torch.tensor(label, dtype=torch.float32)
        return matrix, target
    

def get_dataloaders(
    do_standardization=False,
    do_sum_normalization=False,
    do_log_transform=False,
    is_tiny=False,
    slice_params=None,
):
    """
    Build the train/validation/test ``DataLoader`` objects for the matrix data.

    Pairs are discovered with ``build_pairs(OUTPUT_DIR)`` and split with
    ``split_pairs_torch`` using ``SEED``.  When ``do_standardization`` is set, a
    scalar mean and unbiased std are computed over the preprocessed training
    matrices and passed to all three datasets.

    Args:
        do_standardization: standardize samples with the training mean/std.
        do_sum_normalization: divide each matrix by the sum of its elements.
        do_log_transform: apply the log transform to each matrix.
        is_tiny: return small subsets instead (the first 32 positive and 32
            negative training samples, the first 8 + 8 validation samples,
            batch size 4) and no test loader.
        slice_params: optional dict of ``ymin``/``ymax``/``xmin``/``xmax``
            used to crop every matrix.

    Returns:
        ``(train_loader, valid_loader, test_loader)``; ``test_loader`` is
        ``None`` when ``is_tiny`` is true.
    """
    pairs = build_pairs(OUTPUT_DIR)
    train_pairs, valid_pairs, test_pairs = split_pairs_torch(pairs, seed=SEED)

    # Statistics come from the training split only, so nothing leaks from valid/test.
    train_mean, train_std = None, None
    if do_standardization:
        X = load_train_pairs(train_pairs, do_sum_normalization, do_log_transform, slice_params)
        train_mean = X.mean().item()
        train_std = X.std(unbiased=True).item()

    DefaultMatrixDataset = partial(
        MatrixDataset,
        train_mean=train_mean,
        train_std=train_std,
        do_standardization=do_standardization,
        do_sum_normalization=do_sum_normalization,
        do_log_transform=do_log_transform,
        slice_params=slice_params,
    )
    train_ds = DefaultMatrixDataset(train_pairs)
    valid_ds = DefaultMatrixDataset(valid_pairs)
    test_ds = DefaultMatrixDataset(test_pairs)

    if is_tiny:
        tiny_train_ds = Subset(train_ds, _first_indices_per_class(train_ds, 32))
        tiny_valid_ds = Subset(valid_ds, _first_indices_per_class(valid_ds, 8))

        tiny_train_loader = DataLoader(tiny_train_ds, batch_size=4, shuffle=True)
        tiny_valid_loader = DataLoader(tiny_valid_ds, batch_size=4)
        print(f'sizes:\n {len(tiny_train_ds)} train\n {len(tiny_valid_ds)} val')
        return tiny_train_loader, tiny_valid_loader, None

    DefaultDataLoader = partial(DataLoader, batch_size=BATCH_SIZE)
    train_loader = DefaultDataLoader(train_ds, shuffle=True)
    valid_loader = DefaultDataLoader(valid_ds)
    test_loader = DefaultDataLoader(test_ds)

    print(f'sizes:\n {len(train_ds)} train\n {len(valid_ds)} val\n {len(test_ds)} test\n train mean: {train_mean}\n train std: {train_std}')

    if do_standardization:
        _print_standardization_check(train_loader)

    return train_loader, valid_loader, test_loader


def _first_indices_per_class(dataset, per_class):
    """Return the first ``per_class`` label-1 indices followed by the first ``per_class`` label-0 indices."""
    positives = [i for i, item in enumerate(dataset.items) if item[1] == 1.0][:per_class]
    negatives = [i for i, item in enumerate(dataset.items) if item[1] == 0.0][:per_class]
    return positives + negatives


def _print_standardization_check(loader):
    """Print the mean and std of everything ``loader`` yields, as a sanity check on standardization."""
    total_sum, total_sumsq, total_count = 0.0, 0.0, 0
    for X_batch, _ in loader:
        xb = X_batch.float()
        total_sum += xb.sum().item()
        total_sumsq += (xb * xb).sum().item()
        total_count += xb.numel()

    if total_count == 0:
        # Nothing was loaded, so there is nothing to report.
        return

    sanity_mean = total_sum / total_count
    # E[x^2] - E[x]^2 can dip marginally below zero through rounding; clamp so
    # the square root stays a real number and the format spec below works.
    sanity_var = max(total_sumsq / total_count - sanity_mean * sanity_mean, 0.0)
    sanity_std = sanity_var ** 0.5

    print(f'After standardization\n mean: {sanity_mean:.2f}\n std: {sanity_std:.2f}')

def evaluate_tm(model, data_loader, metric):
    """
    Evaluate ``model`` on ``data_loader`` and return ``metric.compute()``.

    The model is put in eval mode, the metric is reset, and every batch is
    moved to ``device`` and scored with gradients disabled.
    """
    model.eval()
    metric.reset()
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            metric.update(model(inputs), targets)
    return metric.compute()

# Training loop with early stopping, best-model checkpointing and an optional LR scheduler
def train(
    model, optimizer, loss_fn, metric, train_loader, valid_loader, 
    n_epochs, patience=10, checkpoint_path=None, scheduler=None
):
    """
    Train ``model`` with early stopping on the validation metric.

    After every epoch the metric is computed over the training batches and over
    ``valid_loader``.  Whenever the validation value beats the best one seen so
    far, the model's ``state_dict`` is written to ``checkpoint_path``; training
    stops after ``patience`` consecutive epochs without improvement.  Before
    returning, the best weights saved *during this call* are loaded back into
    ``model``.

    Args:
        model: module to train; batches are moved to the global ``device``.
        optimizer: optimizer stepped once per batch.
        loss_fn: callable ``loss_fn(y_pred, y_batch)`` returning a scalar loss.
        metric: object with ``reset``/``update``/``compute``; higher is better.
        train_loader, valid_loader: iterables of ``(X_batch, y_batch)``.
        n_epochs: maximum number of epochs.
        patience: epochs without improvement tolerated before stopping.
        checkpoint_path: where to save the best weights
            (defaults to ``'my_checkpoint.pt'``).
        scheduler: optional LR scheduler; ``ReduceLROnPlateau`` is stepped with
            the validation metric, any other scheduler without arguments.

    Returns:
        dict with per-epoch lists ``'train_losses'`` (mean batch loss),
        ``'train_metrics'`` and ``'valid_metrics'``.
    """
    checkpoint_path = checkpoint_path or 'my_checkpoint.pt'
    history = {'train_losses': [], 'train_metrics': [], 'valid_metrics': []}
    # Start below every possible score so the first finite validation metric
    # always produces a checkpoint (a 0.0 start would skip a metric of exactly 0).
    best_metric = float('-inf')
    checkpoint_saved = False
    patience_counter = 0
    for epoch in range(n_epochs):
        total_loss = 0.0
        metric.reset()
        model.train()
        t0 = time.time()
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_pred = model(X_batch)
            loss = loss_fn(y_pred, y_batch)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # The metric only needs the values, not the autograd graph.
            metric.update(y_pred.detach(), y_batch)

        train_metric = metric.compute().item()
        valid_metric = evaluate_tm(model, valid_loader, metric).item()
        if valid_metric > best_metric:
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            best_metric = valid_metric
            best = ' (best)'
            patience_counter = 0
        else:
            patience_counter += 1
            best = ''

        t1 = time.time()
        history['train_losses'].append(total_loss / len(train_loader))
        history['train_metrics'].append(train_metric)
        history['valid_metrics'].append(valid_metric)
        print(f'Epoch {epoch + 1}/{n_epochs}, '
              f"train loss: {history['train_losses'][-1]:.4f}, "
              f"train metric: {history['train_metrics'][-1]:.4f}, "
              f"valid metric: {history['valid_metrics'][-1]:.4f}{best}"
              f' in {t1 - t0:.1f}s'
        )
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(valid_metric)
            else:
                scheduler.step()
        if patience_counter >= patience:
            print('Early stopping!')
            break

    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        # Nothing was written in this run (no epochs, or the metric was never
        # comparable, e.g. NaN).  Do not load a file left over from an earlier
        # run under the same path; keep the current weights instead.
        print('No checkpoint was saved during this run; keeping the current weights.')
    return history

def build_model(m, seed=SEED):
    """Seed torch's global RNG with ``seed``, call ``m()`` and move the result to ``device``."""
    torch.manual_seed(seed)
    return m().to(device)
    
def plot_training_progress(h):
    """
    Draw three figures comparing all runs in ``h``: training loss, validation
    metric and ROC curve.

    Args:
        h: mapping ``run name -> dict`` where each dict provides the keys
           ``'train_losses'``, ``'valid_metrics'``, ``'fpr'`` and ``'tpr'``.
    """
    # plot key -> (x label, y label)
    axis_labels = {
        'train_losses': ('Epochs', 'Training loss'),
        'valid_metrics': ('Epochs', 'Validation AUC'),
        'roc': ('False positive rate', 'True positive rate'),
    }

    def _plottable(values):
        # matplotlib cannot read tensors living on an accelerator; bring them to host memory.
        if hasattr(values, 'detach'):
            return values.detach().cpu().numpy()
        return values

    for plot, (xlabel, ylabel) in axis_labels.items():
        plt.figure(figsize=(8, 4))
        for opt_name, history in h.items():
            if plot == 'roc':
                plt.plot(_plottable(history['fpr']), _plottable(history['tpr']), label=opt_name, linewidth=3)
            else:
                plt.plot(_plottable(history[plot]), label=opt_name, linewidth=3)

        if plot == 'roc':
            # Draw the chance diagonal once per figure so the legend has a single entry for it.
            plt.plot([0, 1], [0, 1], linestyle='--', linewidth=2, color='r', label='Random guess')

        plt.grid()
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.legend(loc='upper right')
        plt.show()

# Registry of finished runs: the experiment cells store a dict of ROC data and
# training curves under history[tag], and plot_training_progress() reads it.
# NOTE(review): the default factory is `str`, so reading a missing tag yields ''
# rather than a dict - confirm whether a plain dict was intended.
history = defaultdict(str)

def compute_best_roc_data(model, valid_loader, roc_metric):
    """
    Compute ROC curve data for ``model`` on ``valid_loader``.

    The model is evaluated with gradients disabled; targets are cast to int
    before being passed to ``roc_metric.update``.  ``roc_metric.compute()`` must
    return a ``(fpr, tpr, thresholds)`` triple.

    Returns:
        dict with keys ``'fpr'``, ``'tpr'`` and ``'thresholds'``.  Tensor
        results are moved to the CPU so they can be plotted or stored
        regardless of the device the model ran on.
    """
    def _to_cpu(value):
        return value.detach().cpu() if torch.is_tensor(value) else value

    model.eval()
    roc_metric.reset()
    with torch.no_grad():
        for X_batch, y_batch in valid_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_pred = model(X_batch)
            roc_metric.update(y_pred, y_batch.to(torch.int))

    fpr, tpr, thr = roc_metric.compute()
    return {
        'fpr': _to_cpu(fpr),
        'tpr': _to_cpu(tpr),
        'thresholds': _to_cpu(thr),
    }

In [None]:
# SimpleCNNModel:

# SimpleCNNUnit (n times):
#  - Convolution layers (2d)
#  - LeakyReLU 
#  - MaxPool (2d)
# AdaptiveAvgPool2d (global average pooling layer)
# Flatten
# Linear (head -> outputs logits)

class SimpleCNNUnit(nn.Module):
    """
    Convolutional building block: two 3x3 convolutions (padding 1), each followed
    by ``LeakyReLU(0.1)``, then a 2x2 max-pool.

    ``stride`` applies to the first convolution only.  ``dropout`` is accepted
    but not used: no dropout or batch-norm layer is part of the block.
    """

    def __init__(self, in_channels, out_channels, stride=1, dropout=0.1):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=True)
        first_conv = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_kwargs)
        second_conv = nn.Conv2d(out_channels, out_channels, stride=1, **conv_kwargs)
        self.layers = nn.Sequential(
            first_conv,
            nn.LeakyReLU(0.1),
            second_conv,
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, x):
        return self.layers(x)

class SimpleCNNModel(nn.Module):
    """
    Small CNN classifier for single-channel inputs that outputs one logit per sample.

    Architecture: a stem ``SimpleCNNUnit(1, base_channels)`` followed by
    ``num_layers`` further units that double the channel count each time, then
    global average pooling, flatten and a linear head with one output.

    Args:
        base_channels: output channels of the stem unit.
        dropout: accepted for interface compatibility; currently unused.
        num_layers: number of channel-doubling units after the stem.  The
            default of 2 gives channels ``base, 2*base, 4*base``.

    Raises:
        ValueError: if ``num_layers`` is negative.
    """

    def __init__(self, base_channels=16, dropout=0.10, num_layers=2):
        super().__init__()
        if num_layers < 0:
            raise ValueError(f'num_layers must be >= 0, got {num_layers}')

        layers = [
            SimpleCNNUnit(1, base_channels)
        ]

        prev_c = base_channels
        for depth in range(1, num_layers + 1):
            c = base_channels * 2 ** depth
            layers.append(SimpleCNNUnit(prev_c, c))
            prev_c = c

        layers += [
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Flatten(),
            nn.Linear(prev_c, 1),
        ]
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        # Drop the trailing singleton dimension of the head: (B, 1) -> (B,)
        return self.cnn(x).squeeze(-1)

# Experiment: SimpleCNNModel trained on a cropped window of each matrix.
n_epochs = 20
tag = 'cnn_1_800'

model = build_model(SimpleCNNModel)

optimizer = torch.optim.Adam(model.parameters())
# The model emits raw logits; BCEWithLogitsLoss applies the sigmoid itself.
binaryxentropy = nn.BCEWithLogitsLoss()
binary_auc = torchmetrics.classification.BinaryAUROC().to(device)
# mode='max': train() steps this scheduler with the validation metric.
perf_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=2, factor=0.1
)

train_loader, valid_loader, _ = get_dataloaders(
    slice_params=dict(ymin=130, ymax=200, xmin=600, xmax=1400),
)

history_cnn = train(
    model=model,
    optimizer=optimizer,
    loss_fn=binaryxentropy,
    metric=binary_auc,
    train_loader=train_loader,
    valid_loader=valid_loader,
    n_epochs=n_epochs,
    scheduler=perf_scheduler,
    checkpoint_path=f'{tag}.pt',
)

# ROC data of the restored best model on the validation set.
binary_roc = torchmetrics.classification.BinaryROC().to(device)
roc_data = compute_best_roc_data(model, valid_loader, roc_metric=binary_roc)

history[tag] = dict(roc_data, **history_cnn)

In [None]:
# TweakedCNNModel:

class TweakedCNNUnit(nn.Module):
    """
    Convolutional building block with a configurable pooling window: two 3x3
    convolutions (padding 1), each followed by ``LeakyReLU(0.1)``, then a
    max-pool with kernel ``pool_ks``.

    ``stride`` applies to the first convolution only.  ``dropout`` is accepted
    but not used: no dropout or batch-norm layer is part of the block.
    """

    def __init__(self, in_channels, out_channels, stride=1, dropout=0.1, pool_ks=(2,2)):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=True)
        first_conv = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_kwargs)
        second_conv = nn.Conv2d(out_channels, out_channels, stride=1, **conv_kwargs)
        self.layers = nn.Sequential(
            first_conv,
            nn.LeakyReLU(0.1),
            second_conv,
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size=pool_ks),
        )

    def forward(self, x):
        return self.layers(x)

class TweakedCNNModel(nn.Module):
    """
    CNN classifier for single-channel inputs built from three ``TweakedCNNUnit``
    blocks with asymmetric pooling, followed by adaptive average pooling to
    ``(1, 64)`` and a two-layer linear head producing one logit per sample.

    ``dropout`` and ``num_layers`` are accepted but not used.
    """

    def __init__(self, base_channels=32, dropout=0.10, num_layers=2):
        super().__init__()
        # (in_channels, out_channels, pooling kernel) of each unit, in order.
        unit_specs = [
            (1, base_channels, (2, 1)),
            (base_channels, base_channels * 2, (2, 2)),
            (base_channels * 2, base_channels * 4, (1, 2)),
        ]
        blocks = [
            TweakedCNNUnit(c_in, c_out, pool_ks=pool_ks)
            for c_in, c_out, pool_ks in unit_specs
        ]

        last_channels = unit_specs[-1][1]
        head = [
            nn.AdaptiveAvgPool2d((1, 64)),
            nn.Flatten(),
            nn.Linear(last_channels * 64, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, 1),
        ]
        self.cnn = nn.Sequential(*blocks, *head)

    def forward(self, x):
        logits = self.cnn(x)
        return logits.squeeze(-1)

# Experiment: TweakedCNNModel on the tiny subset of the cropped matrices.
n_epochs = 20
tag = 'cnn_1_2000_tweak'

model = build_model(TweakedCNNModel)

optimizer = torch.optim.Adam(model.parameters())
# The model emits raw logits; BCEWithLogitsLoss applies the sigmoid itself.
binaryxentropy = nn.BCEWithLogitsLoss()
binary_auc = torchmetrics.classification.BinaryAUROC().to(device)
# Created but not passed to train() in this run.
perf_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=2, factor=0.1
)

# Preprocessing switches (do_standardization, do_sum_normalization,
# do_log_transform) are left at their defaults here.
train_loader, valid_loader, _ = get_dataloaders(
    is_tiny=True,
    slice_params=dict(ymin=130, ymax=200, xmin=600, xmax=1400),
)

history_cnn = train(
    model=model,
    optimizer=optimizer,
    loss_fn=binaryxentropy,
    metric=binary_auc,
    train_loader=train_loader,
    valid_loader=valid_loader,
    n_epochs=n_epochs,
    checkpoint_path=f'{tag}.pt',
)

# ROC data of the restored best model on the validation set.
binary_roc = torchmetrics.classification.BinaryROC().to(device)
roc_data = compute_best_roc_data(model, valid_loader, roc_metric=binary_roc)

history[tag] = dict(roc_data, **history_cnn)

In [None]:
## MLPOnRelativeMidpointsModel (lengths are collapsed 300x2000 -> 1x2000):

# SimpleMLPUnit (n times):
#  - Linear
#  - ReLU 
# Linear (head -> outputs logits)

class SimpleMLPUnit(nn.Module):
    """Fully connected layer followed by a ReLU activation."""

    def __init__(self, in_features, out_features):
        super().__init__()
        linear = nn.Linear(in_features, out_features)
        self.layers = nn.Sequential(linear, nn.ReLU())

    def forward(self, x):
        return self.layers(x)

class MLPOnRelativeMidpointsModel(nn.Module):
    """
    MLP that sums its input over dimension 2 and classifies the resulting vector.

    The network is a stack of ``SimpleMLPUnit`` blocks with the widths given by
    ``n_neurons``, followed by a linear head with a single output (a logit).
    """

    def __init__(self, n_inputs=2000, n_neurons=[2048, 1024, 1024, 512]):
        super().__init__()
        widths = [n_inputs] + n_neurons
        hidden = [
            SimpleMLPUnit(widths[i], widths[i + 1])
            for i in range(len(n_neurons))
        ]
        head = nn.Linear(n_neurons[-1], 1)
        self.mlp = nn.Sequential(*hidden, head)

    def forward(self, X):
        # Per the original shape notes: (B, 1, 300, 2000) -> (B, 1, 2000) -> (B, 2000)
        collapsed = X.sum(dim=2).squeeze(1)
        logits = self.mlp(collapsed)
        return logits.squeeze(1)

    
def use_he_init(module):
    """
    Initialise an ``nn.Linear`` with Kaiming (He) uniform weights and a zero bias.

    Intended for ``model.apply(use_he_init)``: modules that are not ``nn.Linear``
    are left untouched, and linear layers created with ``bias=False`` only get
    their weights initialised.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.kaiming_uniform_(module.weight)
    # A bias-free layer has module.bias set to None, which zeros_ cannot handle.
    if module.bias is not None:
        nn.init.zeros_(module.bias)
    
# Experiment: MLPOnRelativeMidpointsModel on the uncropped matrices.
n_epochs = 100
tag = 'mlp_1_2000'

model = build_model(MLPOnRelativeMidpointsModel)
model.apply(use_he_init)

optimizer = torch.optim.Adam(model.parameters())
# The model emits raw logits; BCEWithLogitsLoss applies the sigmoid itself.
binaryxentropy = nn.BCEWithLogitsLoss()
binary_auc = torchmetrics.classification.BinaryAUROC().to(device)
# mode='max': train() steps this scheduler with the validation metric.
perf_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=2, factor=0.1
)

# Alternative (disabled): crop the input with
# slice_params={'ymin': 130, 'ymax': 200, 'xmin': 600, 'xmax':1400}
train_loader, valid_loader, _ = get_dataloaders()

history_relative_midpoints = train(
    model=model,
    optimizer=optimizer,
    loss_fn=binaryxentropy,
    metric=binary_auc,
    train_loader=train_loader,
    valid_loader=valid_loader,
    n_epochs=n_epochs,
    scheduler=perf_scheduler,
    checkpoint_path=f'{tag}.pt',
)

# ROC data of the restored best model on the validation set.
binary_roc = torchmetrics.classification.BinaryROC().to(device)
roc_data = compute_best_roc_data(model, valid_loader, roc_metric=binary_roc)

history[tag] = dict(roc_data, **history_relative_midpoints)

In [None]:
## MLPOnFragmentLengthsModel (relative midpoints are collapsed 300x2000 -> 300x1):

# SimpleMLPUnit (n times):
#  - Linear
#  - ReLU 
# Linear (head -> outputs logits)
class MLPOnFragmentLengthsModel(nn.Module):
    """
    MLP that sums its input over dimension 3 and classifies the resulting vector.

    The network is a stack of ``SimpleMLPUnit`` blocks with the widths given by
    ``n_neurons``, followed by a linear head with a single output (a logit).
    """

    def __init__(self, n_inputs=300, n_neurons=[2048, 1024, 1024, 512]):
        super().__init__()
        widths = [n_inputs] + n_neurons
        hidden = [
            SimpleMLPUnit(widths[i], widths[i + 1])
            for i in range(len(n_neurons))
        ]
        head = nn.Linear(n_neurons[-1], 1)
        self.mlp = nn.Sequential(*hidden, head)

    def forward(self, X):
        # Sum away the last dimension, then drop the channel dimension.
        # (A `fragle_transforms` step was disabled in the original code at this point.)
        collapsed = X.sum(dim=3).squeeze(1)
        logits = self.mlp(collapsed)
        return logits.squeeze(1)

    
def use_he_init(module):
    """
    Initialise an ``nn.Linear`` with Kaiming (He) uniform weights and a zero bias.

    Intended for ``model.apply(use_he_init)``: modules that are not ``nn.Linear``
    are left untouched, and linear layers created with ``bias=False`` only get
    their weights initialised.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.kaiming_uniform_(module.weight)
    # A bias-free layer has module.bias set to None, which zeros_ cannot handle.
    if module.bias is not None:
        nn.init.zeros_(module.bias)
    
# Experiment: MLPOnFragmentLengthsModel on the uncropped matrices.
n_epochs = 100
tag = 'mlp_300_1'

model = build_model(MLPOnFragmentLengthsModel)
model.apply(use_he_init)

optimizer = torch.optim.Adam(model.parameters())
# The model emits raw logits; BCEWithLogitsLoss applies the sigmoid itself.
binaryxentropy = nn.BCEWithLogitsLoss()
binary_auc = torchmetrics.classification.BinaryAUROC().to(device)
# mode='max': train() steps this scheduler with the validation metric.
perf_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=2, factor=0.1
)

train_loader, valid_loader, _ = get_dataloaders()

# NOTE(review): this variable name also holds the midpoints model's curves in the
# previous experiment and is overwritten here - confirm whether a separate name is wanted.
history_relative_midpoints = train(
    model=model,
    optimizer=optimizer,
    loss_fn=binaryxentropy,
    metric=binary_auc,
    train_loader=train_loader,
    valid_loader=valid_loader,
    n_epochs=n_epochs,
    scheduler=perf_scheduler,
    checkpoint_path=f'{tag}.pt',
)

# ROC data of the restored best model on the validation set.
binary_roc = torchmetrics.classification.BinaryROC().to(device)
roc_data = compute_best_roc_data(model, valid_loader, roc_metric=binary_roc)

history[tag] = dict(roc_data, **history_relative_midpoints)

In [None]:
# Draw the training-loss, validation-metric and ROC figures for every run stored in `history`.
plot_training_progress(history)

In [None]:
# Take the shape of the first training batch as the input size for the graph.
# NOTE(review): `train_loader` is whichever loader the most recently executed
# experiment cell created - confirm it matches the model visualised below.
# input_size = None
for X_batch, y_batch in train_loader:
    input_size = X_batch.shape
    break
    
# Build a fresh, untrained model for visualisation; this rebinds the global `model`.
# model = build_model(partial(SimpleMLPUnit, in_features=2000, out_features=2048))
model = build_model(MLPOnRelativeMidpointsModel)
model.apply(use_he_init)

# Render the architecture with torchview and save it under the given filename.
model_graph = draw_graph(
    model=model, 
    input_size=input_size, 
    device='meta', 
    expand_nested=True,
    save_graph=True,
    filename='mlp_relative_midpoints'
)
# Last expression of the cell: displays the rendered graph.
model_graph.visual_graph

In [None]:
# Count only the parameters that receive gradients.
learnable_param_count = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(f'Total num of learnable params: {learnable_param_count}')