In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import deque
import matplotlib.pyplot as plt

#Needed for training and evaluation
from losses import *
from RandomMatrixDataSet import get_sample
from RandomMatrixDataSet import RandomMatrixDataSet
from RandomMatrixDataSet import SingularvalueMatrix
from RandomMatrixDataSet import EigenMatrix
from train import train_on_batch
from train import run_training
from evaluation import *
from plotting import plot_loss_logs
from plotting import error_histogram
from plotting import plot_mean_identity_approx
import math
import time
import random
from typing import Tuple
from typing import Optional
from typing import Callable

import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F

from transformer import TransformerModel
from transformer import PositionalEncoding
from transformer import device

# Seed and looks
# `torch.random.seed` is a function; assigning an int to it seeds nothing and
# clobbers the function. Seed the generators explicitly instead.
torch.manual_seed(1234)
random.seed(1234)  # train() draws the matrix size with random.randint
# Figure defaults: wide canvas, text rendered through LaTeX in a serif
# (Palatino) face.
plt.rcParams.update({
    "figure.figsize": [14, 6],
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Palatino"],
})

# Print tensors with 8 digits of precision.
torch.set_printoptions(precision=8)

### Here I generate a src-tgt pair to validate the pipeline (or a test set)

In [8]:
def get_batch(matrix_parameters):
    """Draw one sample set and return its inputs and labels.

    Calls ``get_sample(matrix_parameters)``, asks the returned object to
    compute its labels, and returns its ``X`` and ``Y`` attributes as they
    are (no flattening is applied).

    Args:
        matrix_parameters: passed straight through to ``get_sample``.

    Returns:
        The tuple ``(X, Y)`` taken from the sample object.
    """
    sample = get_sample(matrix_parameters)
    sample.compute_labels()
    inputs, labels = sample.X, sample.Y
    return inputs, labels

# Plain mean-squared-error criterion: the loss used in previous training runs.
mseloss = nn.MSELoss()

def MSE_loss(out, tgt, src):
    """Mean squared error between ``out`` and ``tgt``.

    ``src`` is not used; it is accepted so the function can be called the
    same way as the other losses, ``lossfun(out, tgt, src)``.
    """
    mse = mseloss(out, tgt)
    return mse

def relative_inv_MSE_loss(out, tgt, src):
    """Mean squared deviation of ``out @ src`` from the identity matrix.

    Scores a predicted inverse by how close its product with the input is
    to the identity, so the true inverse ``tgt`` is not needed.

    Args:
        out: predicted inverses. After ``squeeze(1)`` this must be a 3-D
            batch of matrices (assumes a singleton dim 1, e.g.
            ``(batch, 1, d, d)`` -- TODO confirm against the model output).
        tgt: unused; kept so every loss shares the ``(out, tgt, src)``
            signature.
        src: input matrices, same layout as ``out``.

    Returns:
        Scalar tensor: the mean of ``(out @ src - I) ** 2`` over all entries
        of the batch.

    Note:
        Despite the name, the value is not normalised by anything.
    """
    # torch.bmm only accepts 3-D operands, hence the squeeze of dim 1.
    id_approx = torch.bmm(out.squeeze(1), src.squeeze(1))

    # Build the identity on the same device as the product so the loss also
    # works when the tensors are not on the default device. A single (d, d)
    # identity broadcasts over the batch, so no per-sample copy is needed.
    identity = torch.eye(src.shape[-1], device=id_approx.device)

    return (id_approx - identity).square().mean()

def train(model: nn.Module, lossfun: Callable, log_interval: int = 200) -> None:
    """Run one epoch of training on freshly sampled random matrices.

    Relies on module-level state defined in other cells: ``iterations``
    (batches per epoch), ``train_params`` (sampling parameters),
    ``optimizer``, ``scheduler`` (only read for the current learning rate,
    never stepped here) and ``epoch`` (only used in the log line, so it must
    exist whenever a log line is printed).

    Args:
        model: called as ``model(src, tgt)``.
        lossfun: called as ``lossfun(out, tgt, src)``; must return a scalar
            tensor.
        log_interval: number of batches per progress report.
    """
    model.train()  # turn on train mode
    total_loss = 0.
    start_time = time.time()

    for i in range(iterations):
        # Matrices of random size between 3 and 7. Sample from a copy so the
        # shared ``train_params`` dict keeps its configured size for the
        # other cells that read it.
        batch_params = {**train_params, 'd': random.randint(3, 7)}
        src, tgt = get_batch(batch_params)
        optimizer.zero_grad()
        out = model(src, tgt)
        loss = lossfun(out, tgt, src)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Report after every complete window of ``log_interval`` batches, so
        # each average covers exactly ``log_interval`` losses and the final
        # window of the epoch is reported too.
        batches_done = i + 1
        if batches_done % log_interval == 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            print(f'| epoch {epoch:3d} | {batches_done:5d} batches | '
                  f'lr {lr:02.4f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.5f}')
            total_loss = 0
            start_time = time.time()
           
def evaluate(model: nn.Module, eval_parameters, lossfun: Optional[Callable] = None) -> float:
    """Return the mean loss of ``model`` on one freshly sampled batch.

    The earlier body used ``output_flat``, ``criterion`` and ``n_tokens``,
    none of which are assigned in this notebook, so it could not run. This
    version scores a single batch drawn with ``get_batch(eval_parameters)``.

    Args:
        model: called as ``model(src, tgt)``; it is switched to eval mode and
            left there (``train()`` switches it back when it runs).
        eval_parameters: sampling parameters passed to ``get_batch``.
        lossfun: optional loss called as ``lossfun(output, tgt, src)``. When
            omitted, the plain mean squared error between the output and
            ``tgt`` is used.

    Returns:
        The batch loss as a Python float.
    """
    model.eval()  # turn on evaluation mode
    with torch.no_grad():
        src, tgt = get_batch(eval_parameters)
        output = model(src, tgt)
        if lossfun is None:
            loss = (output - tgt).square().mean()
        else:
            loss = lossfun(output, tgt, src)
    return float(loss)

## Instantiate the model

In [9]:
# Matrix parameters
m_size = 3  # Matrix size
k_size = (2, 2)  # Kernel size (used for the unfold/fold patches)
max_seq_length = 128  # Maximum sequence length

# Model parameters
d_hid = 64  # Dimension of the feedforward network model in nn.TransformerEncoder
n_heads = 1  # number of heads in nn.MultiheadAttention
n_layers = 1  # Number of nn.TransformerEncoderLayer in nn.TransformerEncoder
dropout = 0.0  # dropout probability

model = TransformerModel(d_hid, n_heads, n_layers, dropout, max_seq_length, k_size).to(device)

# Training and evaluation parameters
N = 32  # Batch size
train_params = {"N": N, "d": m_size}
eval_params = {"N": 100, "d": m_size}
epochs = 20
iterations = 1000  # batches per epoch: train() loops over range(iterations)
best_val_loss = float('inf')
best_model = None

lr = 1e-3  # learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# step_size is a count of scheduler.step() calls and is documented as an int;
# with step_size=1 the learning rate is multiplied by gamma on every step.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

# lossfun = MSE_loss
lossfun = relative_inv_MSE_loss

# ------------------------------------------------------------------------------
# Debug
# ------------------------------------------------------------------------------
# Smoke-test the pipeline on a single batch before training. `src` and `tgt`
# are reused by the inspection cell further down.
src, tgt = get_batch(train_params)

# Overlap count per entry: unfolding and re-folding a tensor of ones sums the
# contribution of every patch covering an entry, so dividing a folded tensor
# by this undoes the overlap. Created on src's device so the division below
# does not mix devices.
input_ones = torch.ones((1, *src.shape[1:]), dtype=src.dtype, device=src.device)
divisor = F.fold(
    F.unfold(input_ones, kernel_size=k_size),
    kernel_size=k_size,
    output_size=src.shape[-2:],
)

unfolded = F.unfold(tgt, kernel_size=k_size)
folded = F.fold(unfolded, kernel_size=k_size, output_size=src.shape[-2:]) / divisor

# Entries covered by several patches are summed and then divided again, which
# can differ from the input by a rounding error, so compare with a tolerance
# rather than bit-for-bit.
if not torch.allclose(folded, tgt):
    raise ValueError("Invalid unfold/fold operation!")

# Positional encoding over the patch sequence: unfold yields
# (batch, patch_elements, n_patches); the permute puts patches on dim 1.
pos_encoder = PositionalEncoding(math.prod(k_size))
# Call the module itself rather than .forward() so registered hooks would run.
src_pos_encoded = pos_encoder(unfolded.permute(0, 2, 1))

# One forward pass and one loss evaluation to check that shapes line up.
out = model(src, tgt)

loss = lossfun(out, tgt, src)
# ------------------------------------------------------------------------------
# /Debug
# ------------------------------------------------------------------------------


In [None]:
# Training driver. `epoch` is deliberately a module-level loop variable:
# train() reads it for its progress line.
for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()  # only needed by the disabled report below
    train(model, lossfun)
    # Validation and best-model tracking are currently disabled.
    # NOTE(review): before re-enabling, the dict defined in this notebook is
    # `eval_params` (not `eval_parameters`), and `copy` is not among the
    # visible imports.
    #  val_loss = evaluate(model, eval_parameters)
    #  print(f'==[ val_loss: {val_loss}')
    #  val_ppl = math.exp(val_loss)
    #  elapsed = time.time() - epoch_start_time
    #  print('-' * 89)
    #  print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
    #       f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    #  print('-' * 89)

    #  if val_loss < best_val_loss:
    #     best_val_loss = val_loss
    #     best_model = copy.deepcopy(model)

    # Decay the learning rate once per epoch.
    scheduler.step()

In [None]:
# Re-run the trained model on the debug batch (`src`, `tgt`).
out = model(src, tgt)

# Earlier observation: the output just goes to some (or two) fixed values.
print(f'==[ out: {out}')
# Identity approximation out @ src for the first sample of the batch only.
mia = torch.bmm(out.squeeze(1), src.squeeze(1))[0]
print(f'==[ mia: {mia}')
# .cpu() keeps the NumPy conversion working when the tensors are not on the
# CPU; it is a no-op for CPU tensors.
mia = mia.detach().cpu().numpy()

In [None]:
# Heat map of the identity approximation `mia`, each cell annotated with its
# value, saved to transverter.png.
fig = plt.figure()
ax = fig.add_subplot()

# Take the extent from the array itself instead of assuming it is
# m_size x m_size.
n_rows, n_cols = mia.shape

img = ax.imshow(mia, cmap='spring')
for i in range(n_rows):
    for j in range(n_cols):
        # imshow puts column j on the x axis and row i on the y axis.
        ax.text(j, i, round(mia[i, j], 4),
                ha="center", va="center", color="black", fontsize=16)

# Create colorbar
cbar = ax.figure.colorbar(img, ax=ax)

# Ticks sit on the cell boundaries (cell centres are at integers) so the
# white grid lines separate the cells; labels and tick marks are hidden.
ax.set_xticks(np.arange(n_cols) + 0.5)
ax.set_yticks(np.arange(n_rows) + 0.5)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.grid(color="w", linestyle='-', linewidth=3)
ax.tick_params(bottom=False, left=False)

# A title 'Mean $f(X)X$' used to be set first, but it was always overwritten
# by this one, so only the effective title is kept.
ax.set_title("A less underwhelming Transformer inverter", fontsize=16)
plt.savefig("transverter.png")