# Initialization

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from IPython.display import display, clear_output
from tqdm.auto import tqdm
import random
from sklearn.cluster import DBSCAN
import torch.nn.functional as F
import pickle as pkl
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.profiler import profile, record_function, ProfilerActivity
import os
import sys

ROOT = os.path.join("./")
sys.path.append(ROOT + "lib")

from helpers import *
from sourceset import SourceSet

torch.set_default_dtype(torch.float64)
device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
with open(ROOT + "processed_datasets/data_train.pt", "rb") as f:
  data_train = torch.load(f, map_location=device)
with open(ROOT + "processed_datasets/data_valid.pt", "rb") as f:
  data_valid = torch.load(f, map_location=device)

data_train, data_valid

(<sourceset.SourceSet at 0x17975c890>, <sourceset.SourceSet at 0x17990ef90>)

# Model

In [3]:
class FluxAnomalyPredictionLSTMDEPRECATED(nn.Module):
  def __init__(self, stride, dropout, bn=False, features=6, residual=0, out=4):
    super().__init__()

    self.stride = stride
    self.features = features
    self.residual = residual
    self.lstm_hidden_state_features = 14 # Needed for batchnorm 3, so defined here

    # Need Dropouts, Activation fns
    self.he = lambda x: nn.init.kaiming_normal_(x, nonlinearity='relu')
    self.av = nn.ReLU()
    self.drop = nn.Dropout(dropout)

    self.batchnorm1 = lambda x: x
    self.batchnorm2 = lambda x: x
    self.batchnorm3 = lambda x: x

    if bn:
      self.batchnorm1 = nn.BatchNorm1d(features)
      self.batchnorm2 = nn.BatchNorm1d(features)
      self.batchnorm3 = nn.BatchNorm1d(self.lstm_hidden_state_features)

      self.bn1 = lambda x: torch.transpose(self.batchnorm1(torch.transpose(x,1,2)),1,2)
      self.bn2 = lambda x: torch.transpose(self.batchnorm2(torch.transpose(x,1,2)),1,2)
      self.bn3 = lambda x: self.batchnorm3(x)
    else:
      self.bn1 = lambda x: x
      self.bn2 = lambda x: x
      self.bn3 = lambda x: x


    # Step 1

    # 5 x 64 X 64 x 5 -diag-> R^5 Vector
    self.conv_kernels_64d = nn.ParameterList([self.he(torch.randn(features, 64)) for i in range(8)])
    self.conv_biases_64d = nn.ParameterList([torch.randn(features) for i in range(8)])

    self.conv_kernel_16d = nn.Parameter(self.he(torch.randn(features, 16)))
    self.conv_bias_16d = nn.Parameter(torch.randn(features))

    self.conv_kernel_8d = nn.Parameter(self.he(torch.randn(features,8)))
    self.conv_bias_8d = nn.Parameter(torch.randn(features))

    self.four_max_pool = nn.MaxPool1d(4)

    # Step 1.5
    # self.certainty_fc_1 = nn.Parameter(torch.zeros(2,3))
    # self.certainty_fc_2 = nn.Parameter(torch.zeros(3,1))

    # Step 2
    self.widechannel_conv_kernel = nn.Parameter(self.he(torch.randn(features, 3)))
    self.widechannel_conv_bias = nn.Parameter(torch.randn(features))

    self.midchannel_conv_kernel = nn.Parameter(self.he(torch.randn(features, 3)))
    self.midchannel_conv_bias = nn.Parameter(torch.randn(features))

    self.narrowchannel_conv_kernel = nn.Parameter(self.he(torch.randn(features, 3)))
    self.narrowchannel_conv_bias = nn.Parameter(torch.randn(features))


    # Step 3
    self.pair_max_pool = nn.MaxPool1d(2)
    # Then concat all vectors into v \in R^24

    # Step 3.5 Residual connection
    self.__avgpool__ = nn.AvgPool1d(4, stride=4)

    # Step 4
    self.hidden_fc_1 = nn.Linear(3*4*features,3*2*features)
    self.hidden_fc_2 = nn.Linear(3*2*features,2*features)

    # Step 5
    # LSTM

    self.lstm_in_features = 2*self.features

    self.lstm = nn.LSTM(self.lstm_in_features, self.lstm_hidden_state_features)

    # Step 6
    # Softmax

    self.to_out_1 = nn.Linear(self.lstm_hidden_state_features, 7)
    self.to_out_2 = nn.Linear(7, out, bias=False)

    self.prob = nn.Softmax(dim=1)

    for param in self.parameters():
      if len(param.shape) >= 2:
        param = self.he(param)



  def forward(self, x):
    # x = Batches X Time X Channels

    N = x.shape[0] # Batches
    T = x.shape[1] # Time


    # Step 1
    # Cursory vision convolution
    pad_amt = 40

    pad = torch.zeros(N, pad_amt, self.features).to(device)

    padded_x = torch.cat((pad, torch.cat((x, pad), dim=1)), dim=1) # catting in time

    # padded_x.requires_grad = True

    window_centers = [] # AKA 64d Convolve Centers

    n=0 # strides
    while(True):
      next_center = pad_amt + 1 + n * self.stride
      if next_center > (pad_amt + T + 1): # If our center isnt in real data
        break;
      else:
        window_centers.append(next_center)
        n += 1

    window_starts = [self.start_and_end_from_center((64 - 1)/2, i)[0] for i in window_centers]
    window_ends = [self.start_and_end_from_center((64 - 1)/2, i)[1] for i in window_centers]

    midchannel_centers = []
    narrowchannel_centers = []
    for start in window_starts:
      for n in range(8): # 8 strides of 8 -> 64 units
        midchannel_centers.append(start + n * 8)

      for n in range(32): # 32 strides of 2 -> 64 units
        narrowchannel_centers.append(start + n * 2)


    wide_convs = []
    mid_convs = []
    narrow_convs = []

    for i in window_centers:
      for j in range(8): # Hard coded 8 here
        K = self.conv_kernels_64d[j]
        B = self.conv_biases_64d[j].repeat(N,1) # Repeat here
        conv = self.convolve(K, padded_x, i) # Features x T
        conv += B
        wide_convs.append(conv)

    for i in midchannel_centers:
      conv = self.convolve(self.conv_kernel_16d, padded_x, i)
      conv += self.conv_bias_16d.repeat(N,1)
      mid_convs.append(conv)

    for i in narrowchannel_centers:
      conv = self.convolve(self.conv_kernel_8d, padded_x, i)
      conv += self.conv_bias_8d.repeat(N,1)
      narrow_convs.append(conv)

    wide_convs = torch.stack(wide_convs, dim=1).to(device)
    mid_convs = torch.stack(mid_convs, dim=1).to(device)
    narrow_convs = torch.stack(narrow_convs, dim=1).to(device)

    narrow_convs = self.four_max_pool(torch.transpose(narrow_convs, 1,2)) # Inp = N x C x L now
    narrow_convs = torch.transpose(narrow_convs,1,2) # Back to N x L x C


    wide_convs = self.bn1(self.av(wide_convs))
    mid_convs = self.bn1(self.av(mid_convs))
    narrow_convs = self.bn1(self.av(narrow_convs))

    residual_vectors = self.get_residual_vectors(narrow_convs, mid_convs, wide_convs)


    wide_convs = 5 * self.drop(wide_convs)
    mid_convs = self.drop(mid_convs)
    narrow_convs = self.drop(narrow_convs)





    # convs = N x L x C




    #####
    # STEP 1.5
    # Compress our 5vectors to 3 vectors, combining mag and uncertainty indices

    # if not (wide_convs.shape[1] == mid_convs.shape[1] == narrow_convs.shape[1]):
    #   raise Exception("Step 1 output mismatch")


    # compacting_fc = lambda x: self.gelu(torch.matmul(self.gelu(torch.matmul(x, self.certainty_fc_1)), self.certainty_fc_2))

    # wide_col_0 = compacting_fc(wide_convs[:, 0:2])
    # wide_col_1 = compacting_fc(wide_convs[:, 2:4])

    # mid_col_0 = compacting_fc(mid_convs[:, 0:2])
    # mid_col_1 = compacting_fc(mid_convs[:, 2:4])

    # narrow_col_0 = compacting_fc(mid_convs[:, 0:2])
    # narrow_col_1 = compacting_fc(mid_convs[:, 2:4])

    # wide_convs = torch.stack((wide_col_0, wide_col_1, wide_convs[:, -1].unsqueeze(1)), dim=1)
    # mid_convs = torch.stack((mid_col_0, mid_col_1, mid_convs[:, -1].unsqueeze(1)), dim=1)
    # narrow_convs = torch.stack((narrow_col_0, narrow_col_1, narrow_convs[:, -1].unsqueeze(1)), dim=1)

    # wide_convs = wide_convs.squeeze()
    # mid_convs = mid_convs.squeeze()
    # narrow_convs = narrow_convs.squeeze()


    #####
    # STEP 2
    # Second Convolution

    pad_amt = 10
    stride = 1

    results = []

    for x in (wide_convs, mid_convs, narrow_convs):
      T = x.shape[1]
      ker = None
      bias = None
      if len(results) == 0:
        ker = self.widechannel_conv_kernel
        bias = self.widechannel_conv_bias.repeat(N, 1)

      elif len(results) == 1:
        ker = self.midchannel_conv_kernel
        bias = self.midchannel_conv_bias.repeat(N, 1)

      elif len(results) == 2:
        ker = self.narrowchannel_conv_kernel
        bias = self.narrowchannel_conv_bias.repeat(N, 1)


      pad = torch.zeros(N, pad_amt, self.features).to(device)
      padded_x = torch.cat((pad, torch.cat((x, pad), dim=1)), dim=1)


      result = []

      next = pad_amt
      while next <= (pad_amt + T - 1):
        v = bias + self.convolve(ker, padded_x, next)
        result.append(v)
        next += stride


      results.append(torch.stack(result, dim=1))


    wide_convs = self.bn2(self.av(results[0]))
    mid_convs = self.bn2(self.av(results[1]))
    narrow_convs = self.bn2(self.av(results[2]))

    wide_convs = self.drop(wide_convs)
    mid_convs = self.drop(mid_convs)
    narrow_convs = self.drop(narrow_convs)



    ####
    # Step 2.5
    # Max Pooling Pairs
    wide_convs = self.pair_max_pool(torch.transpose(wide_convs, 1,2)) # to N x C x L
    mid_convs = self.pair_max_pool(torch.transpose(mid_convs, 1,2))
    narrow_convs = self.pair_max_pool(torch.transpose(narrow_convs, 1,2))

    wide_convs = torch.transpose(wide_convs, 1,2) # to N x L x C
    mid_convs = torch.transpose(mid_convs, 1,2)
    narrow_convs = torch.transpose(narrow_convs, 1,2)


    ####
    # Step 3

    # Now each original window reigon is each corresponding 2 rows from all 3 tensors
    # 2 rows evenly divides all possible resulting lengths

    if not (wide_convs.shape[1] == mid_convs.shape[1] == narrow_convs.shape[1]):
      raise Exception("Step 3 output mismatch")


    hidden = []
    L = wide_convs.shape[1]

    for n in range(L // 4): # Now each sliding window corresponds to 2 rows
      wide = wide_convs[:, 4*n:4*n + 4].reshape(N, 4*self.features)
      mid = mid_convs[:, 4*n:4*n + 4].reshape(N, 4*self.features)
      narrow = narrow_convs[:, 4*n:4*n + 4].reshape(N, 4*self.features)

      flat = torch.cat((wide, mid, narrow), dim=1).to(device)
      if flat.shape[0] != N or flat.shape[1] != 3*4*self.features:
        raise Exception("Flat shape err")

      layer_1 = self.hidden_fc_1(flat)
      layer_1 = self.av(layer_1)

      # RESIDUAL CONNECTION !
      res = torch.autograd.Variable(self.residual * residual_vectors[n])
      layer_1 = (1-self.residual) * layer_1
      layer_1 = layer_1 + res


      layer_2 = self.hidden_fc_2(layer_1)
      layer_2 = self.av(layer_2)

      hidden.append(layer_2)


    hidden = torch.stack(hidden, dim=0) # Results in L x N x Hidden

    hidden = self.drop(hidden)


    ####
    # Step 6: LSTM

    _, (final_hidden_state, c_n) = self.lstm(hidden)

    final_hidden_state = self.av(final_hidden_state.squeeze())
    final_hidden_state = self.bn3(final_hidden_state)

    final_hidden_state = self.drop(final_hidden_state)


    final_layer = self.to_out_1(final_hidden_state)
    final_layer = self.av(final_layer)

    final_layer = self.drop(final_layer)

    final_layer = self.to_out_2(final_layer)

    classes = self.prob(final_layer)


    return F.log_softmax(final_layer, dim=1)


  def start_and_end_from_center(self, width, i):
    start = i - np.ceil(width)
    end = i + np.floor(width) + 1
    return (int(start), int(end))



  def convolve(self, Kernel, Data, i):

    # i is center index so we take equal on either side
    T = Kernel.shape[1]
    each_side = (T - 1) / 2

    # Moves it backwards 1 if kernel is even
    start, end = self.start_and_end_from_center(each_side, i)

    adj_data = Data[:, start:end, :]

    # So now K is 5 x L and adj_data is N x L X 5


    N = adj_data.shape[0]

    m = torch.bmm(Kernel.repeat(N,1,1), adj_data) # bij, bjk -> bik Slightly faster than einsum

    # m = torch.einsum("ij, bjk -> bik", Kernel, adj_data) #identical batch matmul

    diag = torch.einsum("bii->bi", m)
    # diags = []
    # for res in m:
    #   diags.append(torch.diag(res))


    return diag

  def get_residual_vectors(self, narrow, mid, wide):
    # Should be N x 8m x Features
    # print(narrow.shape, mid.shape, wide.shape)
    N = wide.shape[0]
    L = wide.shape[1]

    apply_avg = lambda matrix: torch.transpose(self.__avgpool__(torch.transpose(matrix, 1,2)), 1,2)

    vectors = []

    if L % 8:
      raise Exception("Something went wrong, convs are not mult of 8")
    for i in range(L // 8):

      n = torch.autograd.Variable(narrow[:, 8*i:8*(i+1), :])
      m = torch.autograd.Variable(mid[:, 8*i:8*(i+1), :])
      w = torch.autograd.Variable(wide[:, 8*i:8*(i+1), :])

      n = apply_avg(n)
      m = apply_avg(m)
      w = apply_avg(w)

      n = n.reshape(N, 2*self.features)
      m = m.reshape(N, 2*self.features)
      w = w.reshape(N, 2*self.features)
      # gives N x 2 * Features

      vectors.append(torch.autograd.Variable(torch.cat((n,m,w), dim=1))) # N x 6 * features

    return vectors

class FluxAnomalyPredictionTF(nn.Module):
  def __init__(self, stride, dropout, bn=False, features=3, residual=0, out=4):
    super().__init__()

    self.stride = stride
    self.features = features
    self.residual = residual
    self.transformer_hidden_dim = 2*features # Needed for batchnorm 3, so defined here

    # Need Dropouts, Activation fns
    self.he = lambda x: nn.init.kaiming_normal_(x, nonlinearity='relu')
    self.av = nn.ReLU()
    self.drop = nn.Dropout(dropout)

    self.batchnorm1 = lambda x: x
    self.batchnorm2 = lambda x: x
    self.batchnorm3 = lambda x: x

    if bn:
      self.batchnorm1 = nn.BatchNorm1d(features)
      self.batchnorm2 = nn.BatchNorm1d(features)
      self.batchnorm3 = nn.BatchNorm1d(self.transformer_hidden_dim)

      self.bn1 = lambda x: torch.transpose(self.batchnorm1(torch.transpose(x,1,2)),1,2)
      self.bn2 = lambda x: torch.transpose(self.batchnorm2(torch.transpose(x,1,2)),1,2)
      self.bn3 = lambda x: self.batchnorm3(x)
    else:
      self.bn1 = lambda x: x
      self.bn2 = lambda x: x
      self.bn3 = lambda x: x


    # Step 1

    # 5 x 64 X 64 x 5 -diag-> R^5 Vector
    self.conv_kernels_64d = nn.ParameterList([self.he(torch.randn(features, 64)) for i in range(8)])
    self.conv_biases_64d = nn.ParameterList([torch.randn(features) for i in range(8)])

    self.conv_kernel_16d = nn.Parameter(self.he(torch.randn(features, 16)))
    self.conv_bias_16d = nn.Parameter(torch.randn(features))

    self.conv_kernel_8d = nn.Parameter(self.he(torch.randn(features,8)))
    self.conv_bias_8d = nn.Parameter(torch.randn(features))

    self.four_max_pool = nn.MaxPool1d(4)

    # Step 1.5
    # self.certainty_fc_1 = nn.Parameter(torch.zeros(2,3))
    # self.certainty_fc_2 = nn.Parameter(torch.zeros(3,1))

    # Step 2
    self.widechannel_conv_kernel = nn.Parameter(self.he(torch.randn(features, 3)))
    self.widechannel_conv_bias = nn.Parameter(torch.randn(features))

    self.midchannel_conv_kernel = nn.Parameter(self.he(torch.randn(features, 3)))
    self.midchannel_conv_bias = nn.Parameter(torch.randn(features))

    self.narrowchannel_conv_kernel = nn.Parameter(self.he(torch.randn(features, 3)))
    self.narrowchannel_conv_bias = nn.Parameter(torch.randn(features))


    # Step 3
    self.pair_max_pool = nn.MaxPool1d(2)
    # Then concat all vectors into v \in R^24

    # Step 3.5 Residual connection
    self.__avgpool__ = nn.AvgPool1d(4, stride=4)

    # Step 4
    self.hidden_fc_1 = nn.Linear(3*4*features, 3*2*features)
    self.hidden_fc_2 = nn.Linear(3*2*features, self.transformer_hidden_dim)

    # Step 5
    # Transformer
    self.transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=self.transformer_hidden_dim, nhead=self.features)
    self.transformer_encoder = nn.TransformerEncoder(self.transformer_encoder_layer, 2)


    # Step 6
    # Softmax

    self.to_out_1 = nn.Linear(self.transformer_hidden_dim, 7)
    self.to_out_2 = nn.Linear(7, out, bias=False)

    self.prob = nn.Softmax(dim=1)

    for param in self.parameters():
      if len(param.shape) >= 2:
        param = self.he(param)



  def forward(self, x):
    # x = Batches X Time X Channels

    N = x.shape[0] # Batches
    T = x.shape[1] # Time

    if x.shape[2] != self.features:
      raise Exception("Feature dimension mismatch")

    # Step 1
    # Cursory vision convolution
    pad_amt = 40

    pad = torch.zeros(N, pad_amt, self.features).to(device)
  

    padded_x = torch.cat((pad, torch.cat((x, pad), dim=1)), dim=1) # catting in time

    # padded_x.requires_grad = True

    window_centers = [] # AKA 64d Convolve Centers

    n=0 # strides
    while(True):
      next_center = pad_amt + 1 + n * self.stride
      if next_center > (pad_amt + T + 1): # If our center isnt in real data
        break;
      else:
        window_centers.append(next_center)
        n += 1

    window_starts = [self.start_and_end_from_center((64 - 1)/2, i)[0] for i in window_centers]
    window_ends = [self.start_and_end_from_center((64 - 1)/2, i)[1] for i in window_centers]

    midchannel_centers = []
    narrowchannel_centers = []
    for start in window_starts:
      for n in range(8): # 8 strides of 8 -> 64 units
        midchannel_centers.append(start + n * 8)

      for n in range(32): # 32 strides of 2 -> 64 units
        narrowchannel_centers.append(start + n * 2)


    wide_convs = []
    mid_convs = []
    narrow_convs = []

    for i in window_centers:
      for j in range(8): # Hard coded 8 here
        K = self.conv_kernels_64d[j]
        B = self.conv_biases_64d[j].repeat(N,1) # Repeat here
        conv = self.convolve(K, padded_x, i) # Features x T
        conv += B
        wide_convs.append(conv)

    for i in midchannel_centers:
      conv = self.convolve(self.conv_kernel_16d, padded_x, i)
      conv += self.conv_bias_16d.repeat(N,1)
      mid_convs.append(conv)

    for i in narrowchannel_centers:
      conv = self.convolve(self.conv_kernel_8d, padded_x, i)
      conv += self.conv_bias_8d.repeat(N,1)
      narrow_convs.append(conv)

    wide_convs = torch.stack(wide_convs, dim=1).to(device)
    mid_convs = torch.stack(mid_convs, dim=1).to(device)
    narrow_convs = torch.stack(narrow_convs, dim=1).to(device)

    narrow_convs = self.four_max_pool(torch.transpose(narrow_convs, 1,2)) # Inp = N x C x L now
    narrow_convs = torch.transpose(narrow_convs,1,2) # Back to N x L x C


    wide_convs = self.bn1(self.av(wide_convs))
    mid_convs = self.bn1(self.av(mid_convs))
    narrow_convs = self.bn1(self.av(narrow_convs))

    residual_vectors = self.get_residual_vectors(narrow_convs, mid_convs, wide_convs)


    wide_convs = 5 * self.drop(wide_convs)
    mid_convs = self.drop(mid_convs)
    narrow_convs = self.drop(narrow_convs)



    #####
    # STEP 2
    # Second Convolution

    pad_amt = 10
    stride = 1

    results = []

    for x in (wide_convs, mid_convs, narrow_convs):
      T = x.shape[1]
      ker = None
      bias = None
      if len(results) == 0:
        ker = self.widechannel_conv_kernel
        bias = self.widechannel_conv_bias.repeat(N, 1)

      elif len(results) == 1:
        ker = self.midchannel_conv_kernel
        bias = self.midchannel_conv_bias.repeat(N, 1)

      elif len(results) == 2:
        ker = self.narrowchannel_conv_kernel
        bias = self.narrowchannel_conv_bias.repeat(N, 1)


      pad = torch.zeros(N, pad_amt, self.features).to(device)
      padded_x = torch.cat((pad, torch.cat((x, pad), dim=1)), dim=1)


      result = []

      next = pad_amt
      while next <= (pad_amt + T - 1):
        v = bias + self.convolve(ker, padded_x, next)
        result.append(v)
        next += stride


      results.append(torch.stack(result, dim=1))


    wide_convs = self.bn2(self.av(results[0]))
    mid_convs = self.bn2(self.av(results[1]))
    narrow_convs = self.bn2(self.av(results[2]))

    wide_convs = self.drop(wide_convs)
    mid_convs = self.drop(mid_convs)
    narrow_convs = self.drop(narrow_convs)



    ####
    # Step 2.5
    # Max Pooling Pairs
    wide_convs = self.pair_max_pool(torch.transpose(wide_convs, 1,2)) # to N x C x L
    mid_convs = self.pair_max_pool(torch.transpose(mid_convs, 1,2))
    narrow_convs = self.pair_max_pool(torch.transpose(narrow_convs, 1,2))

    wide_convs = torch.transpose(wide_convs, 1,2) # to N x L x C
    mid_convs = torch.transpose(mid_convs, 1,2)
    narrow_convs = torch.transpose(narrow_convs, 1,2)


    ####
    # Step 3

    # Now each original window reigon is each corresponding 2 rows from all 3 tensors
    # 2 rows evenly divides all possible resulting lengths

    if not (wide_convs.shape[1] == mid_convs.shape[1] == narrow_convs.shape[1]):
      raise Exception("Step 3 output mismatch")


    hidden = []
    L = wide_convs.shape[1]

    for n in range(L // 4): # Now each sliding window corresponds to 2 rows
      wide = wide_convs[:, 4*n:4*n + 4].reshape(N, 4*self.features)
      mid = mid_convs[:, 4*n:4*n + 4].reshape(N, 4*self.features)
      narrow = narrow_convs[:, 4*n:4*n + 4].reshape(N, 4*self.features)

      flat = torch.cat((wide, mid, narrow), dim=1).to(device)
      if flat.shape[0] != N or flat.shape[1] != 3*4*self.features:
        raise Exception("Flat shape err")

      layer_1 = self.hidden_fc_1(flat)
      layer_1 = self.av(layer_1)

      # RESIDUAL CONNECTION !
      res = torch.autograd.Variable(self.residual * residual_vectors[n])
      layer_1 = (1-self.residual) * layer_1
      layer_1 = layer_1 + res

      layer_2 = self.hidden_fc_2(layer_1)
      layer_2 = self.av(layer_2)

      hidden.append(layer_2)

    hidden = torch.stack(hidden, dim=0) # Results in (Divided L) x N x Hidden
    hidden = self.drop(hidden)


    seq = self.transformer_encoder(hidden) # same shape as hidden



    transformed = nn.AvgPool1d(seq.shape[0])(torch.transpose(seq,0,2)) # Pooling happens on last dim
    transformed = torch.squeeze(transformed, dim=2)
    transformed = torch.transpose(transformed, 0, 1) # Should be N x Hidden

    transformed = self.av(transformed)
    transformed = self.bn3(transformed)
    transformed = self.drop(transformed)

    final_layer = self.to_out_1(transformed)
    final_layer = self.av(final_layer)
    final_layer = self.drop(final_layer)
    final_layer = self.to_out_2(final_layer)

    # classes = self.prob(final_layer) # Dont use, use CEL instead

    return F.log_softmax(final_layer, dim=1)


  def start_and_end_from_center(self, width, i):
    start = i - np.ceil(width)
    end = i + np.floor(width) + 1
    return (int(start), int(end))



  def convolve(self, Kernel, Data, i):

    # i is center index so we take equal on either side
    T = Kernel.shape[1]
    each_side = (T - 1) / 2

    # Moves it backwards 1 if kernel is even
    start, end = self.start_and_end_from_center(each_side, i)

    adj_data = Data[:, start:end, :]

    # So now K is 5 x L and adj_data is N x L X 5


    N = adj_data.shape[0]

    m = torch.bmm(Kernel.repeat(N,1,1), adj_data) # bij, bjk -> bik Slightly faster than einsum

    # m = torch.einsum("ij, bjk -> bik", Kernel, adj_data) #identical batch matmul

    diag = torch.einsum("bii->bi", m)
    # diags = []
    # for res in m:
    #   diags.append(torch.diag(res))


    return diag

  def get_residual_vectors(self, narrow, mid, wide):
    # Should be N x 8m x Features
    # print(narrow.shape, mid.shape, wide.shape)
    N = wide.shape[0]
    L = wide.shape[1]

    apply_avg = lambda matrix: torch.transpose(self.__avgpool__(torch.transpose(matrix, 1,2)), 1,2)

    vectors = []

    if L % 8:
      raise Exception("Something went wrong, convs are not mult of 8")
    for i in range(L // 8):

      n = torch.autograd.Variable(narrow[:, 8*i:8*(i+1), :])
      m = torch.autograd.Variable(mid[:, 8*i:8*(i+1), :])
      w = torch.autograd.Variable(wide[:, 8*i:8*(i+1), :])

      n = apply_avg(n)
      m = apply_avg(m)
      w = apply_avg(w)

      n = n.reshape(N, 2*self.features)
      m = m.reshape(N, 2*self.features)
      w = w.reshape(N, 2*self.features)
      # gives N x 2 * Features

      vectors.append(torch.autograd.Variable(torch.cat((n,m,w), dim=1))) # N x 6 * features

    return vectors

class lstm(nn.Module):
  def __init__(self, into, hidden, out):
    super().__init__()

    self.rl = nn.functional.sigmoid
    self.ls = nn.LSTM(into, hidden, batch_first=True)
    self.to_out = nn.Linear(hidden,out)
    self.prob = nn.Softmax(dim=1)

  def forward(self, x):
    _, (f, __) = self.ls(x)

    return self.to_out(f.squeeze())

class lstm2(nn.Module):
  def __init__(self, into, hidden, out):
    super().__init__()

    self.rl = nn.functional.sigmoid
    self.ls = nn.LSTM(into, hidden, batch_first=True, num_layers=2)
    self.to_out = nn.Linear(hidden,out)
    self.prob = nn.Softmax(dim=1)

  def forward(self, x):
    _, (f, __) = self.ls(x)

    return self.to_out(f[1, :].squeeze())

In [4]:
model = FluxAnomalyPredictionTF(20, 0.15, bn=False, features=3, residual=0).to(device)
# model = lstm(6, 20, 4).to(device)

model_pars = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_pars])

print(params)

82997


In [5]:
def padded_collate(tensors):
  datas = []
  labels = []
  for tensor, label in tensors:
    datas.append(tensor)
    labels.append(label)

  batched = nn.utils.rnn.pad_sequence(datas, batch_first=True).to(device)
  labels = torch.stack(labels, dim=0)


  return (batched, labels)



train = DataLoader(data_train, batch_size=len(data_train), shuffle=True, collate_fn=padded_collate)
valid = DataLoader(data_valid, batch_size=len(data_valid), shuffle=True, collate_fn=padded_collate)
# test = DataLoader(data_test, batch_size=None, shuffle=True)
next(iter(train))[0].shape

torch.Size([480, 999, 3])

# General Trainer

In [6]:
from time import perf_counter
loss_fn = nn.CrossEntropyLoss().to(device)

def compete(trainers, epochs, trainloader, validloader):
  # trainers are model-optimizer pairs

  progress_bar = tqdm(total=epochs, desc="Training Progress")
  loss_fn = nn.CrossEntropyLoss().to(device)


  trainloss = {}
  validloss = {}

  nullts = {}
  novats = {}
  pulsatingts = {}
  transitts = {}
  accuracyts = {}

  for name, _, __ in trainers:
    trainloss[name] = []
    validloss[name] = []
    nullts[name] = []
    novats[name] = []
    pulsatingts[name] = []
    transitts[name] = []
    accuracyts[name] = []


  for e in range(epochs):
    t1 = perf_counter()


    for name, model, optim in trainers:
      epoch_loss = []
      valid_loss = []

      null_correct = 0
      nova_correct = 0
      pulsating_correct = 0
      transit_correct = 0
      correct = 0


      novas = 0
      pulsators = 0
      transits = 0
      nulls = 0
      exs = 0

      for data, label in trainloader:
        model.train()
        out = model(data)

        optim.zero_grad()

        loss = loss_fn(out, label)
        epoch_loss.append(loss.item())

        loss.backward()
        print(out, label)
        print(loss)
        optim.step()


      for data, label in validloader:
        model.eval()
        out = model(data)
        loss = loss_fn(out, label)
        valid_loss.append(loss.item())
        i = torch.argmax(out, dim=1).cpu()
        j = torch.argmax(label, dim=1).cpu()

        for idx, jdx in zip(i,j):
          exs+=1
          if idx == jdx:
            correct += 1

          if jdx == 0:
            nulls += 1
            if idx == jdx:
              null_correct += 1

          if jdx == 1:
            novas += 1
            if idx == jdx:
              nova_correct +=1
          if jdx == 2:
            pulsators += 1
            if idx == jdx:
              pulsating_correct +=1
          if jdx == 3:
            transits += 1
            if idx == jdx:
              transit_correct +=1

      training_loss_epoch = np.mean(epoch_loss)
      validation_loss_epoch = np.mean(valid_loss)

      nullac = null_correct / (nulls + 0.000001)
      novacc = nova_correct / (novas + 0.000001)
      pulsatoracc = pulsating_correct / (pulsators + 0.00001)
      transitacc = transit_correct / (transits + 0.00001)
      accuracy = correct / exs

      trainloss[name].append(training_loss_epoch)
      validloss[name].append(validation_loss_epoch)

      nullts[name].append(nullac)
      novats[name].append(novacc)
      pulsatingts[name].append(pulsatoracc)
      transitts[name].append(transitacc)
      accuracyts[name].append(accuracy)



    dt = perf_counter() - t1
    if dt < 0.65:
      if e % int(1.75/dt) == 0:
        for name, _, __ in trainers:
          clear_output()
          fig = getprogressplot(trainloss[name], validloss[name], accuracyts[name], nullts[name], novats[name], pulsatingts[name], transitts[name], epochs, e)
          fig.update_layout(title=name+' Statistics: {}/{}'.format(e, epochs),
                        xaxis_title='Epochs',
                        yaxis_title='Loss',
                        width=650)
          display(fig)
    else:
      for name, _, __ in trainers:
        clear_output()
        fig = getprogressplot(trainloss[name], validloss[name], accuracyts[name], nullts[name], novats[name], pulsatingts[name], transitts[name], epochs, e)
        fig.update_layout(title=name+' Statistics: {}/{}'.format(e, epochs),
                      xaxis_title='Epochs',
                      yaxis_title='Loss',
                      width=1000)
        display(fig)

    progress_bar.update(1)

In [9]:
# model = FluxAnomalyPredictionTF(14, 0.15, bn=False, features=3, residual=0).to(device)
# model.train()
# model = FluxAnomalyPredictionLSTM(14, 0.15, bn=False, features=3, residual=0).to(device)
# model = lstm(3,10,4).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

compete([("TF", model, optimizer)], 720, train, valid)

Training Progress:   2%|▏         | 14/720 [00:32<27:07,  2.30s/it]

KeyboardInterrupt: 

In [None]:
with open(ROOT + "models/tf_str14_dp015_bnF_f3_res0.pt", "wb") as f:
  torch.save(model.state_dict(),f)

# Training Loop (Single)

In [7]:
EPOCHS = 500
lr = 0.007

model = FluxAnomalyPredictionTF(20, 0.15, bn=False, features=3, residual=0).to(device)

optim = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss().to(device)

progress_bar = tqdm(total=EPOCHS, desc="Training Progress")

trainloss = []
validloss = []

nullts = []
novats = []
pulsatingts = []
transitts = []
accuracyts = []

for e in range(EPOCHS):
  epoch_loss = []
  valid_loss = []

  null_correct = 0
  nova_correct = 0
  pulsating_correct = 0
  transit_correct = 0
  correct = 0


  novas = 0
  pulsators = 0
  transits = 0
  nulls = 0
  exs = 0

  for data, label in train:

    model.train()
    out = model(data)

    loss = loss_fn(out, label)

    epoch_loss.append(loss.item())

    loss.backward()
    optim.step()
    optim.zero_grad()


  for data, label in valid:
    model.eval()
    out = model(data)
    loss = loss_fn(out, label)
    valid_loss.append(loss.item())
    i = torch.argmax(out, dim=1).cpu()
    j = torch.argmax(label, dim=1).cpu()

    for idx, jdx in zip(i,j):
      exs += 1
      if idx == jdx:
        correct += 1

      if jdx == 0:
        nulls += 1
        if idx == jdx:
          null_correct += 1

      if jdx == 1:
        novas += 1
        if idx == jdx:
          nova_correct +=1
      if jdx == 2:
        pulsators += 1
        if idx == jdx:
          pulsating_correct +=1
      if jdx == 3:
        transits += 1
        if idx == jdx:
          transit_correct +=1

  training_loss_epoch = np.mean(epoch_loss)
  validation_loss_epoch = np.mean(valid_loss)

  nullac = null_correct / (nulls + 0.000001)
  novacc = nova_correct / (novas + 0.000001)
  pulsatoracc = pulsating_correct / (pulsators + 0.00001)
  transitacc = transit_correct / (transits + 0.00001)
  accuracy = correct / exs

  trainloss.append(training_loss_epoch)
  validloss.append(validation_loss_epoch)

  nullts.append(nullac)
  novats.append(novacc)
  pulsatingts.append(pulsatoracc)
  transitts.append(transitacc)
  accuracyts.append(accuracy)


  progress_bar.update(1)
  p = getprogressplot(trainloss, validloss, accuracyts, nullts, novats, pulsatingts, transitts, EPOCHS, e)
  clear_output(wait=False)
  display(p)

  print("Epoch ", e, ": ", training_loss_epoch)
  print("nulls:", nullac, "novas: ", novacc, "pulsators: ", pulsatoracc, "transits: ", transitacc)

x = range(EPOCHS)

p = getprogressplot(trainloss, validloss, accuracyts, nullts, novats, pulsatingts, transitts, EPOCHS, e)
clear_output(wait=True)
display(p)
print("Epoch ", e, ": ", training_loss_epoch)

with open(ROOT + "models/model.pt", "wb") as f:
  torch.save(model.named_parameters(),f)

Epoch  499 :  0.585808219844024


TypeError: cannot pickle 'generator' object

In [15]:
with open("models/model.pt", "wb") as f:
    torch.save(model.state_dict(), f)

# Diagnosis

In [None]:
first = next(iter(train))

In [None]:
model = FluxAnomalyPrediction(32, 0, True, residual=0).to(device)
loss = loss_fn(model(first[0]), first[1])
optim.zero_grad()
loss.backward()
# for p, t in model.cpu().named_parameters():
#   print(p, t.grad)
plot_grad_flow(model.cpu().named_parameters())
plt.show()

model = FluxAnomalyPrediction(32, 0, True, residual=0.5).to(device)
loss = loss_fn(model(first[0]), first[1])
optim.zero_grad()
loss.backward()
# for p, t in model.cpu().named_parameters():
#   print(p, t.grad)
plot_grad_flow(model.cpu().named_parameters())
plt.show()

In [None]:
model = FluxAnomalyPrediction(32, 0.1).to(device)


with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("dd"):
        ex = next(iter(train))
        out = model(ex[0])
        loss = loss_fn(out, ex[1])
        loss.backward()
        for name, param in model.cpu().named_parameters():
          print(name, param.grad)
        plot_grad_flow(model.cpu().named_parameters())

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# Testing / Toy Model

In [None]:
class bigboy(nn.Module):
  def __init__(self, out):
    super().__init__()
    self.av = nn.ReLU()

    self.to_one = nn.Parameter(torch.zeros(6))

    self.fc1 = nn.Linear(350, 200)
    self.fc2 = nn.Linear(200, 100)
    self.fc3 = nn.Linear(100, 10)



    self.fc4 = nn.Linear(10, out)

    self.prob = nn.Softmax(dim=1)

  def forward(self, x):

    # x = torch.einsum("bij,j->bi", x, self.to_one)
    x = x[:, :, 0]
    T = x.shape[1]

    pad = torch.zeros(x.shape[0], 350 - T).to(device)

    padded = torch.cat((x, pad), dim=1)


    current = padded

    for f in (self.fc1, self.fc2, self.fc3, self.fc4):
      current = self.av(f(current))

    return self.prob(current)








In [None]:
class PseudoSet(Dataset):
  def __init__(self, classes):
    self.classes = classes

    self.all = []
    for i, class_ in enumerate(self.classes):
      label = torch.zeros(len(self.classes))
      label[i] = 1
      for ex in class_:
        self.all.append((ex, label))


  def __getitem__(self, idx):
    return self.all[idx]

  def __len__(self):
    return len(self.all)

with open(ROOT + "datasets/toy_data_train.pt", "rb") as f:
  toydata_train = torch.load(f, map_location=device)
with open(ROOT + "datasets/toy_data_valid.pt", "rb") as f:
  toydata_valid = torch.load(f, map_location=device)

toytrain = DataLoader(toydata_train, batch_size=len(toydata_train), collate_fn=padded_collate, shuffle=True)
toyvalid = DataLoader(toydata_valid, batch_size=len(toydata_valid), collate_fn=padded_collate, shuffle=True)

In [None]:
ex = toydata_train[100][0].cpu()
plt.scatter(ex[:, -2], ex[:, 0])
print(toydata_train[200][1])

In [None]:
model = FluxAnomalyPredictionLSTM(15, 0.15, residual=0, features=6, out=3).to(device)
# model = lstm(6, 20, 3).to(device)
# model = bigboy(3).to(device)


optim = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=10**-6.5)
loss_fn = nn.CrossEntropyLoss().to(device)

trainers = [
    ("Flux Anomaly Predictor", model, optim)
]

compete(trainers, 150, train, valid)
