In [1]:
import os
import numpy as np
import torch 
import argparse 
import torch.nn.functional as F 
from src.dataset import load_dataset_1d 
from src.utils import (
    init_records, save_hist, get_seed, rl2_error)
import json 
from tqdm import trange
import time

from src.model import MLP
from src.dd_gmg import DD_GMG1D
from src.green_net import GreenNet1D



Pseudo data (random tensors used only as timing inputs)

In [2]:
# Fixed seed so repeated runs draw the same random data (and, because the global
# torch RNG continues, the same model initialisation / sub-sampling later on).
torch.manual_seed(0)

n = 9
bsz = 200
# Derived from n rather than hard-coded, so changing n cannot leave the data
# length out of sync with the grid size used by the models below.
seq_len = 2**n + 1
u = torch.randn(seq_len, bsz)
f = torch.randn(seq_len, bsz)

Speed measurement for GreenNet 1D

In [3]:
# Candidate values of `p` passed to GreenNet1D. The timing cell sends p < 1
# through the `*_sub` kernel path and p >= 1 through the full-kernel path.
ps = [
    0.01, 0.03, 0.05, 0.07,
    0.1, 0.15, 0.25, 0.3,
    0.4, 0.5, 0.7, 0.9,
    1.0,
]

In [10]:
# Benchmark configuration. The recorded CPU run managed roughly 16 it/s, so
# 1000 timed iterations take about a minute; lower `iters` for a quick check.
iters = 1000
warmup_iters = 5  # untimed steps so one-off setup costs don't skew the averages

cuda = False  # set True to benchmark on 'cuda:0'
if cuda:
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')


u = torch.randn(seq_len, bsz).to(device)
f = torch.randn(seq_len, bsz).to(device)

p = ps[3]
in_channels = 2
out_channels = 1
hidden_channels = 50
# Layer widths handed to MLP: [2, 50, 50, 50, 50, 1].
layers = [in_channels] + [hidden_channels] * 4 + [out_channels]
kernel = MLP(layers, nonlinearity='rational', aug='none_1d').to(device)
model = GreenNet1D(n=n, kernel=kernel, device=device, p=p)
model.rand_sub()

opt_adam = torch.optim.Adam(kernel.parameters(), lr=1e-3)

model.kernel.train()


def _sync():
    """Wait for queued CUDA work (no-op on CPU) so wall-clock timestamps are meaningful."""
    if cuda:
        torch.cuda.synchronize()


def _train_step():
    """Run one Adam step on the relative-L2 loss and time its forward phases.

    Uses the sub-sampled kernel (``eval_K_sub`` / ``sub_kint``) when ``p < 1``
    and the full kernel (``eval_K`` / ``full_kint``) otherwise.

    Returns:
        (keval_seconds, kint_seconds): time spent evaluating the kernel and
        time spent in the kernel integral, respectively.
    """
    if p < 1:
        eval_k, kint = model.eval_K_sub, model.sub_kint
    else:
        eval_k, kint = model.eval_K, model.full_kint

    keval_start = time.perf_counter()
    eval_k()
    # CUDA launches are asynchronous; without a sync these sub-timings would
    # measure only launch latency, not the actual computation.
    _sync()
    keval_end = time.perf_counter()
    u_ = kint(f)
    _sync()
    kint_end = time.perf_counter()

    target = u[model.sub] if p < 1 else u
    loss = rl2_error(u_.T, target.T)

    opt_adam.zero_grad()
    loss.backward()  # use the l2 relative loss
    opt_adam.step()
    return keval_end - keval_start, kint_end - keval_end


for i in trange(warmup_iters):
    _train_step()
_sync()

times = []
keval_times = []
kint_times = []

for i in trange(iters):
    _sync()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    keval_elapsed, kint_elapsed = _train_step()
    _sync()
    times.append(time.perf_counter() - start_time)
    keval_times.append(keval_elapsed)
    kint_times.append(kint_elapsed)

# Average over the samples actually recorded (equals `iters` when the loop completes).
avg_time = sum(times) / len(times)
keval_avg_time = sum(keval_times) / len(keval_times)
kint_avg_time = sum(kint_times) / len(kint_times)

print("full train : ", avg_time)
print("keval train : ", keval_avg_time)
print("kint train : ", kint_avg_time)

100%|██████████| 5/5 [00:00<00:00, 13.88it/s]
 11%|█         | 106/1000 [00:06<00:54, 16.31it/s]


KeyboardInterrupt: 

Speed measurement for GreenMGNet 1D