In [1]:
import os, sys
sys.path.append('./mylayers')
sys.path.append('./utils')
sys.path.append('./data')

In [2]:
from mgconv import MGConv
from resmgunit import ResMGUnit

from avg_eigmodule import avg_eigmodule
from prepare_cosine_matrix import prepare_cosine_matrix

from specradloss import SpecRadLoss
from nuclearnormloss import NuclearNormLoss
from maxsingularvalueloss import MaxSingularValueLoss

from normal_toeplitz_generator import normal_toeplitz_generator
from comp_gmres_iters import comp_gmres_iters

In [3]:
import numpy as np
import wandb

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

from tqdm import tqdm

In [4]:
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Data preparation

In [5]:
dims_datasets = {}
dims_dataloaders = {}

PATH_TO_DATA = './data/DATA/'
MODE = 'Random'

In [6]:
dims_datasets['Train64'] = torch.imag(torch.load(f'{PATH_TO_DATA}Matrix64/{MODE}ToeplitzTrain10000.pth'))
dims_datasets['Train128'] = torch.imag(torch.load(f'{PATH_TO_DATA}Matrix128/{MODE}ToeplitzTrain10000.pth'))
dims_datasets['Train256'] = torch.imag(torch.load(f'{PATH_TO_DATA}Matrix256/{MODE}ToeplitzTrain10000.pth'))

dims_datasets['Valid64'] = torch.imag(torch.load(f'{PATH_TO_DATA}Matrix64/{MODE}ToeplitzValid1000.pth'))
dims_datasets['Valid128'] = torch.imag(torch.load(f'{PATH_TO_DATA}Matrix128/{MODE}ToeplitzValid1000.pth'))
dims_datasets['Valid256'] = torch.imag(torch.load(f'{PATH_TO_DATA}Matrix256/{MODE}ToeplitzValid1000.pth'))

In [7]:
for key in dims_datasets.keys():
    dims_datasets[key] = torch.utils.data.TensorDataset(dims_datasets[key][:,None,:].to(device))

In [8]:
for key in dims_datasets.keys():
    if key[:6] == 'Train':
        dims_dataloaders[key] = DataLoader(dims_datasets[key], batch_size=32, shuffle=True, drop_last=True)
    else:
        dims_dataloaders[key] = DataLoader(dims_datasets[key], batch_size=32, shuffle=False)

# Network

In [9]:
class MyNetwork(nn.Module):
    
    def __init__(self, depth,
                       channels):
        
        super().__init__()
        
        assert depth == len(channels)
        assert depth >= 2
        
        self.depth = depth
        self.channels = channels
        
        self.coarsers = nn.ModuleList([])
        for i in range(1, depth):
            self.coarsers.append(nn.Upsample(scale_factor=0.5**i,
                                             mode='linear'))
        self.convs = nn.ModuleList([])
        for i in range(depth):
            self.convs.append(nn.Conv1d(1, channels[i],
                                        kernel_size=3, padding=1))
        
        self.mg_1 = ResMGUnit(depth, channels)
        
        self.ac_1 = nn.ModuleList([nn.ReLU() for i in range(depth)])
        
        self.mg_2 = ResMGUnit(depth, channels)
        
        self.ac_2 = nn.ModuleList([nn.ReLU() for i in range(depth)])
                                   
        self.mg_3 = ResMGUnit(depth, channels)
        
    def forward(self, x):
        
        meshes = [0 for _ in range(self.depth)]
        meshes[0] = x
        for i in range(1, self.depth):
            meshes[i] = self.coarsers[i - 1](x)
        
        upd_meshes = [0 for _ in range(self.depth)]
        for i in range(self.depth):
            upd_meshes[i] = self.convs[i](meshes[i])
            
        y = self.mg_1(upd_meshes)
        
        z = [0 for _ in range(self.depth)]
        for i in range(self.depth):
            z[i] = self.ac_1[i](y[i])
        
        w = self.mg_2(z)
        
        u = [0 for _ in range(self.depth)]
        for i in range(self.depth):
            u[i] = self.ac_2[i](y[i])
            
        v = self.mg_3(u)
        
        return v

In [10]:
class BigPipeline(nn.Module):
    
    def __init__(self, depth, 
                       channels):
        
        super().__init__()
        
        self.depth = depth
        self.channels = channels
        
        self.net = MyNetwork(depth, channels)
        
    def forward(self, x):
        
        matrix_dim = 2 * x.shape[2]
        mat = prepare_cosine_matrix(matrix_dim)
        
        x_expanded = torch.zeros((x.shape[0], x.shape[1], matrix_dim))
        x_expanded[...,:matrix_dim // 2] = x[...,:matrix_dim // 2]
        x_expanded[...,matrix_dim // 2:-1] = torch.flip(x[...,1:matrix_dim // 2],
                                                        dims=[2])
        
        y = (mat @ x_expanded[:,0,:].T).T[:,None,:]
        
        z = self.net(y)[0]
        
        u = 1 / matrix_dim * (mat @ z[:,0,:].T).T[:,None,:]
        
        return u[...,:matrix_dim // 2]

# Train loop

In [11]:
@torch.no_grad()
def evaluate(net, 
             valid_dataloader, 
             device):
    
    net.eval()
    
    total_loss, total_avg_eig_true, total_avg_eig_pred = 0., [], []
    count = 0
    
    for X_batch in valid_dataloader:
    
        X_batch = X_batch[0]
        out = net(X_batch)
        
        bs = out.shape[0]
        
        #total_loss += bs * torch.mean((out - y_true) ** 2)
        #total_avg_eig_true += avg_eigmodule(X_batch, y_true)
        total_avg_eig_pred += avg_eigmodule(X_batch, out)
        
        count += bs

    return np.mean(total_avg_eig_pred)

def train(epoch_num, 
          net, 
          optimizer, 
          criterion,
          scheduler,
          train_dataloaders, 
          valid_dataloader, 
          device, 
          name):

    wandb.init(project="sirius", name=name)
    global_step = 0
    net = net.to(device)

    iters_per_epoch = len(train_dataloaders[0])

    for epoch in tqdm(range(epoch_num)):
        
        for inputs in zip(*train_dataloaders):
            
            optimizer.zero_grad()
            loss = torch.tensor(0.)
            total_bs = 0
            
            for cur_input in inputs:
                
                cur_input[0] = cur_input[0].to(device)
                
                output = torch.squeeze(net(cur_input[0]))
                cur_input = torch.squeeze(cur_input[0])
                
                bs = output.shape[0]
                total_bs += bs
            
                loss += criterion(cur_input, output) * bs
                
            loss = loss / total_bs
            loss.backward()
            
            optimizer.step()
            
            wandb.log({"train/loss": loss.item()}, step=global_step)

            if global_step % 100 == 0:

                avg_valid_eig_module_pred  = evaluate(net, valid_dataloader, device)
                net.train()
                
                wandb.log({"eval/avg_valid_eig_module_pred": avg_valid_eig_module_pred}, step=global_step)

            global_step += 1
            
        scheduler.step()

    wandb.finish()

In [12]:
net = BigPipeline(3, [20, 16, 12])
#criterion = SpecRadLoss(n=4,n_samples=1000, alpha_reg=0.)
#criterion_main = MaxSingularValueLoss(10)
criterion_reg = NuclearNormLoss()
#criterion = lambda x, y: criterion_main(x, y) + 1e-2 * criterion_reg(x, y)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [13]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

get_n_params(net)

33888

In [14]:
train(
    epoch_num=5,
    net=net,
    optimizer=optimizer,
    criterion=criterion_reg,
    scheduler=scheduler,
    train_dataloaders=[dims_dataloaders['Train64'], dims_dataloaders['Train128'], dims_dataloaders['Train256']],
    valid_dataloader=dims_dataloaders['Valid64'],
    device=device,
    name="test_upd_128")

[34m[1mwandb[0m: Currently logged in as: [33ms02210401[0m ([33mcmcmsu[0m). Use [1m`wandb login --relogin`[0m to force relogin














 40%|██████████████████████████████████████████████████████▊                                                                                  | 2/5 [2:15:17<3:22:56, 4058.69s/it]


KeyboardInterrupt: 

In [None]:
train(
    epoch_num=100,
    net=net,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    valid_dataloader=val_dataloader,
    device=device,
    name="test_upd_128")

[34m[1mwandb[0m: Currently logged in as: [33ms02210401[0m ([33mcmcmsu[0m). Use [1m`wandb login --relogin`[0m to force relogin


 10%|█████████████▋                                                                                                                           | 10/100 [18:24<2:46:44, 111.16s/it]

In [16]:
torch.save({'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, 
           'best_practice.path')

In [None]:
stds = 2 * np.ones(128) / np.arange(1, 128 + 1)
comp_gmres_iters(net, normal_toeplitz_generator(stds, 45))

With:  1000
Without:  29
With:  1000
Without:  16
With:  10
Without:  4
With:  1000
Without:  6


In [None]:
checkpoint = torch.load('best_practice.path')
net.load_state_dict(checkpoint['model_state_dict'])