<a href="https://colab.research.google.com/github/gerritgr/nextaid_bdprocess/blob/main/Full_Autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
%matplotlib inline

In [55]:
import tqdm as notebook_tqdm
#from .autonotebook import tqdm as notebook_tqdm
import random
import networkx as nx
import os
import pickle 
import time
import pandas as pd
import glob
import numpy as np
import matplotlib.pyplot as plt
import math
import sys

import torch
from torch import nn
import torch.nn.functional as F

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

HOME ='./'
THIS_FILE = 'Full_Autoencoder.ipynb'   # very ugly to get automatically, needed to autosave

### Hyperparams

In [47]:
LATENT_DIM = 128
NUM_ATOMTYPES = 43 #https://www.blopig.com/blog/2022/02/how-to-turn-a-smiles-string-into-a-molecular-graph-for-pytorch-geometric/
LEARNING_RATE = 0.000001
BATCH_SIZE = 32#
ACCEPT_PEAK_THRESHOLD = 0.3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cpu')

# Utils

## Logging

In [84]:
import logging
logging.basicConfig( format = '%(asctime)s  %(levelname)-10s %(processName)s  %(name)s %(message)s', level = logging.INFO, filename = THIS_FILE+"_my.log")

    
def printlog(*args, **kw):
    output = ', '.join([str(a) for a in args])
    kwargs = [str((k,v)) for k,v in kw.items()]
    output = output + ' - '+ ', '.join(kwargs)
    logging.info(output) 
    try:
        print(*args, **kw)
    except:
        print(output)

## Create File

In [5]:
if not os.path.isfile('HIVspectra_graph_data_full.pickle'):
    import shutil
    printlog('unpack HIVspectra_graph_data_full')
    shutil.unpack_archive('HIVspectra_graph_data_full.zip', './')

## Seed

In [6]:
# stack overflow
def str_to_float(s, encoding="utf-8"):
  from zlib import crc32
  def bytes_to_float(b):
    return float(crc32(b) & 0xffffffff) / 2**32
  return bytes_to_float(s.encode(encoding))


# very important to make results reproducible
def set_seed(exp_name):
    name_as_int = int(str_to_float(exp_name)*10000000)
    np.random.seed(name_as_int)
    torch.random.manual_seed(name_as_int)
    random.seed(name_as_int) 
    
set_seed('1234')

## Misc

In [7]:
def setup_experiment(exp_name):
    set_seed(exp_name)
    exp_name = HOME + exp_name
    os.system('mkdir ' + exp_name)
    os.system('mkdir '+ exp_name + '/weights')
    os.system('cp "'+THIS_FILE+'" '+exp_name+'/main_notebook.ipynb')
    return exp_name


def load_model(exp_folder_path):
  filepath = sorted(list(glob.glob(exp_folder_path + '/weights/training_state*')))
  if len(filepath) == 0:
    return None
  state = torch.load(filepath[-1], map_location=DEVICE)
  printlog('Found state file: ', filepath)
  return state

def plot_inputoutput(input_output, path):
    x_values = input_output[0]
    y_values = input_output[1]
    try:
        x_values = x_values[:3000]
        y_values = y_values[:3000]
    except:
        pass
    plt.clf()
    plt.scatter(x_values, y_values, edgecolors='blue', color='None', alpha=0.3)
    plt.xlabel('Real Peak Num')
    plt.ylabel('Predicted Peak Num')
    plt.title('Training Data')
    plt.savefig(path)
    

def reduce_vector(y, length_new=128):
    assert(length_new <= y.numel())
    
    y_flat = y.flatten()
    y_new = y_flat[:length_new] * 0.0 # to get same type/device
    length_old = y_flat.numel()
    for i in range(length_old):
        i_new = (float(i)/length_old)*length_new
        i_new = int(i_new)
        y_new[i_new] += y_flat[i]
        
    return y_new
        
    

# Loss and Acc

## Custom Multi-Scale Encoder Loss and Acc

In [8]:
def split_list(length_of_list):
    sublist_num = 0.5
    index_list = list()
    while True:
        sublist_num = int(sublist_num*2)
        sublist_len = int(length_of_list/sublist_num)
        assert(int(sublist_len * sublist_num) == int(length_of_list))
        for i in range(sublist_num):
            index_list.append((i*sublist_len, (i+1)*sublist_len)) 
        if int(sublist_num) == int(length_of_list):
            break
    return index_list

# accepts only flattened tensors
def fill_to_pow2(x):
    l = x.shape[1]
    p = int(math.log2(l))
    l_new = int(l - 2**p)
    if l_new == 0:
        return x
    l_diff = int(2**(p+1) - l)
    mean_vec =  torch.mean(x, dim=1, keepdim =True)
    x = torch.cat((x, mean_vec.repeat(1, l_diff)), dim=1) 
    return x

# operates in-place (flatten)
# only works on vectors of length 2^x or will extend tensors accordingly
def compute_diff(l1, l2, advanced_weight=True):
    assert l1.shape == l2.shape
    assert len(l1.shape) in [1, 2]
    # first dim is batch
    if len(l1.shape) == 1:
        l1 = l1.reshape(1,-1)
        l2 = l2.reshape(1,-1)
    

    l1 = fill_to_pow2(l1)
    l2 = fill_to_pow2(l2)

    indices = split_list(l1.shape[1])
    diff = torch.sum(l1, dim=1, keepdim=True)*0.0  # to have tensor on correct decice and tpye
    for (i1, i2) in indices:
        sum1 = torch.sum(l1[:, i1:i2], dim=1, keepdim=True)
        sum2 = torch.sum(l2[:, i1:i2], dim=1, keepdim=True)
        diff_i = (sum1-sum2)**2
        if advanced_weight:
            diff_i = diff_i * (i2-i1)/len(l1)
        diff = diff + diff_i 
    return diff


def compute_acc(out, gt, peak_threshold = ACCEPT_PEAK_THRESHOLD):
    try:
        out = out.flatten().detach().cpu().tolist()
    except:
        pass
    try:
        gt = gt.flatten().detach().cpu().tolist()
    except:
        pass
    
    pred = [1.0 if out[i] > peak_threshold else 0.0 for i in range(len(out))]
    gt = [1.0 if gt[i]>0.5 else 0.0 for i in range(len(gt))]
    
    gt_pred_intersec =  [1.0 if gt[i]>0.5 and pred[i] >0.5 else 0.0 for i in range(len(gt))]
    gt_pred_union =  [1.0 if gt[i]>0.5 or pred[i] >0.5 else 0.0 for i in range(len(gt))]
    
    # compute jaccard
    return np.sum(gt_pred_intersec)/np.sum(gt_pred_union)
    
    

## GNN-Based Graph Distance Loss

# GNN

## Encoder

In [56]:
from torch.nn import ModuleList, Linear, ReLU, Sequential
from torch_geometric.nn import PNAConv, BatchNorm, global_add_pool

class GNN_ENC(torch.nn.Module):
    def __init__(self, hidden_channels=LATENT_DIM, node_input_size=79, edge_input_size=10, out_channel=128, depth_graph=10, depth_mlp=5):
        super(GNN_ENC, self).__init__()
        
        # Graph layers
        self.convs = ModuleList()
        self.convs.append(PDNConv(node_input_size, hidden_channels, edge_dim=edge_input_size, hidden_channels=hidden_channels))
        for _ in range(depth_graph-1): 
            self.convs.append(PDNConv(hidden_channels, hidden_channels, edge_dim=edge_input_size, hidden_channels=hidden_channels))
        
        # Final MLP        
        mlp_list = list()
        for _ in range(depth_mlp-1):
            mlp_list.append(Linear(hidden_channels, hidden_channels))
            mlp_list.append(ReLU())
        mlp_list.append(Linear(hidden_channels, out_channel))
        self.mlp = Sequential(*mlp_list)
        

        
    def forward(self, x, edge_index, edge_attr,  batch):

        # Graph layers
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index, edge_attr))
        x = F.relu(conv(x, edge_index, edge_attr))
        
        # Readout layer
        x = global_add_pool(x, batch)  


        # 3. Apply a final classifier
        x = F.dropout(x, p=0.1, training=self.training)
        x = self.mlp(x)
        x = x**2
        
        return x


## Decoder

In [57]:
# UNUSED
class GNN_DEC(torch.nn.Module):
    def __init__(self, hidden_channels=LATENT_DIM, max_number_of_nodes=10, depth_mlp=15):
        super(GNN_DEC, self).__init__()
        
        self.out_channel = int(max_number_of_nodes*(max_number_of_nodes-1)/2+max_number_of_nodes)  # size of upper triangular matrix with diagonal
        
        # MLP        
        mlp_list = list()
        for _ in range(depth_mlp-1):
            mlp_list.append(Linear(hidden_channels, hidden_channels))
            mlp_list.append(ReLU())
        mlp_list.append(Linear(hidden_channels, self.out_channel))
        self.mlp = Sequential(*mlp_list)
        
        
    def forward(self, x):

        x = self.mlp(x)
        
        return x

### Decoder Atom Count Vec

In [58]:
class GNN_DEC_Atomcount(torch.nn.Module):
    def __init__(self, hidden_channels=LATENT_DIM, number_atomtypes=NUM_ATOMTYPES, depth_mlp=15):
        super(GNN_DEC_Atomcount, self).__init__()
        
        self.out_channel = number_atomtypes
        
        # MLP        
        mlp_list = list()
        for _ in range(depth_mlp-1):
            mlp_list.append(Linear(hidden_channels, hidden_channels))
            mlp_list.append(ReLU())
        mlp_list.append(Linear(hidden_channels, self.out_channel))
        self.mlp = Sequential(*mlp_list)
        
        
    def forward(self, x):

        atom_count_vec = self.mlp(x)
        
        return atom_count_vec

## Autoencoder

In [75]:
class Autoencoder(torch.nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.encoder = GNN_ENC().to(DEVICE)
        self.decoder = GNN_DEC_Atomcount().to(DEVICE)


    def forward(self, x, edge_index, edge_attr, batch):
        pred_spectrum = self.encoder(x, edge_index, edge_attr, batch)
        pred_atom_count_vec = self.decoder(pred_spectrum)

        return pred_spectrum, pred_atom_count_vec

# Visualize

In [63]:
# make sure batch size is 1
@torch.no_grad()
def visualize(data_loader, model, path):
    model.eval()

    for i, data in enumerate(data_loader):
      data = data.to(DEVICE)
      out = model(data.x, data.edge_index, data.edge_attr, data.batch)
      out = out.flatten().detach().cpu().tolist()
      y = data.y.flatten().detach().cpu().tolist()

      plt.close()
      plt.plot(out, alpha=0.6)
      plt.plot(y, alpha=0.5)
      plt.savefig(path.replace('NUM', str(i).zfill(4)),  bbox_inches='tight', dpi=300)
        
      plt.clf()
    
    
      pred = [(i, 1.0) if out[i] > ACCEPT_PEAK_THRESHOLD else (i,0.0) for i in range(len(out))]
      gt = [(i, 1.0) if y[i]>0.5 else (i, 0.0) for i in range(len(y))]
    
      #gt = [(i, y[i]) for i in range(len(y)) if y[i] > 0.9]
      #pred = [(i, -out[i]) for i in range(len(out)) if out[i] > 0.40 and out[i] >= out[max(0,i-1)] and out[i] >= out[min(len(out)-1, i+1)]]
      #plt.scatter([x[0] for x in gt], [x[1] for x in gt], alpha=0.8)
      #plt.scatter([x[0] for x in pred], [x[1] for x in pred], alpha=0.8)
    
      plt.scatter([0, len(out)], [0, 1.0], alpha=0.0, c='black')  #dummy
      for x,y in gt:
          plt.axvline(x=x, c='red', alpha=0.5, ls='--', lw=3, label='gt')
      for x,y in pred:
          plt.axvline(x=x, c='blue', alpha=0.5, label='pred')
            
      plt.legend()      
      plt.savefig(path.replace('NUM', 'scatter_'+str(i).zfill(4)),  bbox_inches='tight', dpi=300)

    return 

# Training

In [81]:
def train(model, train_loader, optimizer):
    model.train()
    spectrum_loss_list = list()
    count_loss_list = list()
    l1_loss = nn.L1Loss()

    for data in train_loader:  
        data = data.to(DEVICE)
        optimizer.zero_grad()  
        pred_spectrum, pred_atom_count_vec = model(data.x, data.edge_index, data.edge_attr, data.batch)
        gt_spectrum = data.y.reshape(pred_spectrum.shape)

        # latent loss to optimize spectra
        spectrum_loss = compute_diff(gt_spectrum, pred_spectrum)
        spectrum_loss = torch.mean(spectrum_loss) # reduce to dim 1

        # count loss to optimize reconstruction
        gt_atom_count_vec = global_add_pool(data.x[:,:NUM_ATOMTYPES], data.batch)
        count_loss = l1_loss(gt_atom_count_vec, pred_atom_count_vec)

        loss = spectrum_loss + count_loss

        loss.backward()  
        optimizer.step()  

        spectrum_loss_list.append(spectrum_loss.item())
        count_loss_list.append(count_loss.item())
    return np.mean(spectrum_loss_list), np.mean(count_loss_list) # only correct when each batch equally sized


@torch.no_grad()
def test(model, data_loader):
    model.eval()
    spectrum_acc_list = list()
    count_loss_list = list()
    l1_loss = nn.L1Loss()

    for data in data_loader:
        data = data.to(DEVICE)
        pred_spectrum, pred_atom_count_vec = model(data.x, data.edge_index, data.edge_attr, data.batch)
        
        # spectrum loss
        y = data.y.reshape(pred_spectrum.shape)
        for j in range(pred_spectrum.shape[0]):
            out_line = pred_spectrum[j,:]
            gt_line = y[j,:]
            acc = compute_acc(out_line, gt_line)
            spectrum_acc_list.append(acc)

        # count loss to optimize reconstruction
        gt_atom_count_vec = global_add_pool(data.x[:,:NUM_ATOMTYPES], data.batch)
        count_loss = l1_loss(gt_atom_count_vec, pred_atom_count_vec)
        count_loss_list.append(count_loss.item())
        
    return  np.mean(spectrum_acc_list), np.mean(count_loss_list)

### Putting everything together

In [99]:
def run_experiment(exp_name, model, train_loader, test_loader, optimizer, viz_loader = None, epoch_num = 1000):
    set_seed(exp_name)
    printlog('Start experiment: ', exp_name)
    epoch_num += 1 # because we start with epoch 1
    exp_folder_path = setup_experiment(exp_name)
    epoch_start = 1

    # load model
    state = load_model(exp_folder_path)
    if state is not None:
        optimizer.load_state_dict(state['optimizer'])
        model.load_state_dict(state['state_dict'])
        epoch_start = state['epoch']
        
    for epoch in range(epoch_start, epoch_num):
        start = time.time()
        spectrum_loss_train, count_loss_train = train(model, train_loader, optimizer)
        end = time.time()
        epoch_time = end - start
        if epoch == epoch_start or epoch % 100 == 0 or epoch == epoch_num-1 or epoch<10:
            spectrum_acc_train, count_loss_train = test(model, train_loader)
            spectrum_acc_test, count_loss_test = test(model, test_loader)
            s = repr(f'Epoch: {epoch:03d}, Spectrum Loss: {spectrum_loss_train:.4f}, Count Loss: {count_loss_train:.4f}, Train Jaccard: {spectrum_acc_train:.4f}, Test Jaccard: {spectrum_acc_test:.4f}, Epoch Time: {epoch_time:.5f}')
            printlog(s)

            state = {'optimizer': optimizer.state_dict(), 'epoch': epoch+1, 'state_dict': model.state_dict()} 
            torch.save(state, exp_folder_path+'/weights/training_state.pickle')
            if viz_loader is not None:
                visualize(viz_loader, model, exp_folder_path+'/prediction_{}_NUM.png'.format(str(epoch).zfill(10)))

    

# Load Data

In [92]:
graph_data = None
try:
    graph_data = pickle.load(open('HIVspectra_graph_data_reduced_{}.pickle'.format(int(LATENT_DIM)), "rb"))
    print('read reduced graph file')
except:
    print('create reduced file')
    graph_data = pickle.load(open('HIVspectra_graph_data_full.pickle', "rb"))
    for graph in graph_data:
        graph.y = reduce_vector(graph.y, length_new = LATENT_DIM)
    pickle.dump(graph_data, open('HIVspectra_graph_data_reduced_{}.pickle'.format(int(LATENT_DIM)), "wb"))


read reduced graph file


In [93]:
c = int(len(graph_data)*0.8)
graph_data_train = graph_data[0:c]
graph_data_test = graph_data[c:]

train_loader = DataLoader(graph_data_train, batch_size=BATCH_SIZE, shuffle=True)  
test_loader = DataLoader(graph_data_test, batch_size=BATCH_SIZE)
viz_loader = DataLoader(graph_data_train[:20], batch_size=1)

# Experiments

### Exp 1 

In [None]:
GNN_autoenc = Autoencoder()
GNN_autoenc = GNN_autoenc.to(DEVICE)
optimizer = torch.optim.Adam(GNN_autoenc.parameters(), lr=LEARNING_RATE)


run_experiment('exp1_autoenc', GNN_autoenc, train_loader, test_loader, optimizer, viz_loader = None, epoch_num = 10000)

Start experiment:  exp1_autoenc


mkdir: cannot create directory ‘./exp1_autoenc’: File exists
mkdir: cannot create directory ‘./exp1_autoenc/weights’: File exists


'Epoch: 001, Spectrum Loss: 1706.4960, Count Loss: 0.6233, Train Jaccard: 0.0000, Test Jaccard: 0.0000, Epoch Time: 165.37094'
