In [1]:
import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, Upsample
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import BatchNorm
import numpy as np
from torch_geometric.data import Data
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torch.distributions import normal, kl

from tqdm import tqdm

from sklearn.model_selection import KFold
import pandas as pd
from MatrixVectorizer import MatrixVectorizer


In [2]:
# set global variables
N_SUBJECTS = 167

N_LR_NODES = 160

N_HR_NODES = 268

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

N_LR_NODES_F = int(N_LR_NODES * (N_LR_NODES-1) / 2)
N_HR_NODES_F = int(N_HR_NODES * (N_HR_NODES-1) / 2)

# Model Layers

In [3]:
class SheafConvLayer(nn.Module):
    def __init__(self, n_nodes, d, f_in, f_out=None):
        super().__init__()
        self.d = d
        self.n_nodes = n_nodes
        self.f_in = f_in
        self.f_out = f_out
        # random init weight matrices
        if f_out is None:
            f_out = f_in 
        self.weight1 = nn.Parameter(torch.randn((d, d), device=DEVICE))
        self.weight2 = nn.Parameter(torch.randn((f_in, f_out), device=DEVICE))
        self.edge_weights = nn.Parameter(torch.randn((n_nodes, n_nodes, d, 2*d), device=DEVICE))


    def forward(self, X, adj):
        kron_prod = torch.kron(torch.eye(self.n_nodes).to(DEVICE), self.weight1)
        L = self.sheaf_laplacian(X, adj)
        if self.f_out is None:
            return X - F.elu(L @ kron_prod @ X @ self.weight2), L
        else:
            return F.elu(L @ kron_prod @ X @ self.weight2), L


    def sheaf_laplacian(self, X, adj, epsilon=1e-6):
        X_reshaped = X.reshape(self.n_nodes, self.d, -1)
        idx_pairs = torch.cartesian_prod(torch.arange(self.n_nodes), torch.arange(self.n_nodes))
        all_stacked_features = X_reshaped[idx_pairs].reshape(self.n_nodes, self.n_nodes, 2*self.d, -1).to(DEVICE)
        lin_trans = F.elu(torch.matmul(self.edge_weights, all_stacked_features))
        inner_transpose = torch.transpose(lin_trans, -1, -2)
        L_v = -1 * torch.matmul(lin_trans, torch.transpose(inner_transpose, 0, 1))
        # row_cond = torch.isclose(torch.sum(adj, dim=1), torch.zeros_like(torch.sum(adj, dim=1)))
        # col_cond = torch.isclose(torch.sum(adj, dim=0), torch.zeros_like(torch.sum(adj, dim=0)))
        adj_row_weights = adj / (torch.sum(adj, dim=1)[:, None] + epsilon)
        adj_col_weights = adj / (torch.sum(adj, dim=0)[:, None] + epsilon)
        # adj_col_weights = torch.where(col_cond[None, :], 0., adj / torch.sum(adj, dim=0)[None, :])
        adj_weights = torch.maximum(adj_row_weights * adj_col_weights, torch.zeros_like(adj_row_weights))

        adj_diag_weights = adj_row_weights ** 2
        diag_blocks = torch.sum(adj_diag_weights[:, :, None, None] * torch.matmul(lin_trans, inner_transpose), dim=1)
        L_v[range(self.n_nodes), range(self.n_nodes)] = diag_blocks
        return L_v.reshape(-1, self.n_nodes * self.d)
        ### NOTE IGNORE MATRIX NORMALISATION FOR NOW #####
        # inv_root_diag_blocks = torch.pow(diag_blocks+epsilon, -1/2)
        # normalise_mat = torch.block_diag(*inv_root_diag_blocks)

        # return normalise_mat @ L_v.view(-1, self.n_nodes * self.d) @ normalise_mat
        ################################################

In [4]:
class AdjacencyDimChanger(nn.Module):
    def __init__(self, new_n, old_n, old_f, d):
        super().__init__()
        self.new_n = new_n
        self.old_n = old_n
        self.d = d
        self.sheafconv = SheafConvLayer(old_n, d, old_f, new_n)
        self.layernorm = nn.LayerNorm([d, old_n]).to(DEVICE)

    def forward(self, X, adj):

        adj = adj - torch.diag_embed(torch.diagonal(adj, 0)).to(DEVICE) + torch.eye(adj.shape[0]).to(DEVICE)  # add self connections
        x, L = self.sheafconv(X, adj)

        x = x.reshape(self.old_n, self.d, self.new_n)
        x = torch.transpose(x, 0, -1)
        x = self.layernorm(x)
        
        x_mean = x.mean(dim=-1)

        L_mean = L.reshape(self.old_n, self.old_n, self.d, self.d).max(dim=0)[0].mean(dim=0) # aggregate by eigenvalues of each n by n mat?
        adj_new = torch.matmul(x_mean, L_mean)
        adj_new = torch.matmul(adj_new, x_mean.T)
        adj_new_T = torch.t(adj_new)
        adj_new = F.tanh(F.relu(((adj_new + adj_new_T) / 2))) # becomes a new f by new f adj1

        return x.reshape(self.new_n*self.d, -1), adj_new

In [5]:
class AdjacencyChangerUp(nn.Module):

    def __init__(self, d, f_in):
        super().__init__()
        self.d = d

        self.adjdim_changer1 = AdjacencyDimChanger(216, N_LR_NODES, f_in, d)
        self.adjdim_changer2 = AdjacencyDimChanger(256, 216, N_LR_NODES, d)
        self.adjdim_changer3 = AdjacencyDimChanger(N_HR_NODES, 256, 216, d)

        # self.adjdim_changer1 = AdjacencyDimChanger(256, N_LR_NODES, f_in, d)
        # self.adjdim_changer2 = AdjacencyDimChanger(N_HR_NODES, 256, N_LR_NODES, d)

        
    def forward(self, X, adj):
        x1, adj1 = self.adjdim_changer1(X, adj)
        x2, adj2 = self.adjdim_changer2(x1, adj1)
        # return [adj, adj1, adj2]

        x3, adj3 = self.adjdim_changer3(x2, adj2)
        return [adj, adj1, adj2, adj3]
        

In [6]:
class AdjacencyChangerDown(nn.Module):

    def __init__(self, d, f_in):
        super().__init__()
        self.d = d

        self.adjdim_changer1 = AdjacencyDimChanger(256, N_HR_NODES, f_in, d)
        self.adjdim_changer2  = AdjacencyDimChanger(216, 256, N_HR_NODES, d)
        self.adjdim_changer3 = AdjacencyDimChanger(N_LR_NODES, 216, 256, d)

        # self.adjdim_changer1 = AdjacencyDimChanger(256, N_HR_NODES, f_in, d)
        # self.adjdim_changer2  = AdjacencyDimChanger(N_LR_NODES, 256, N_HR_NODES, d)

        
    def forward(self, X, adj):
        x1, adj1 = self.adjdim_changer1(X, adj)
        x2, adj2 = self.adjdim_changer2(x1, adj1)
        # return [adj, adj1, adj2]


        x3, adj3 = self.adjdim_changer3(x2, adj2)
        return [adj, adj1, adj2, adj3]

In [7]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

def unfreeze_model(model):
    for param in model.parameters():
        param.requires_grad = True

In [7]:
import numpy as np
import networkx as nx

def eigen_centrality(data):
    # ROI is the number of brain regions (i.e.,35 in our case)
    ROI = 160

    topology_eigen = []

    G = nx.from_numpy_array(np.absolute(data))
    U = G.to_undirected()

    # A = to_2d(data)
    np.fill_diagonal(data, 0)

    # create a graph frL2
    # # compute egeinvector centrality and transform the output to vector
    ec = nx.eigenvector_centrality_numpy(U)
    
    eigenvector_centrality = np.array([ec[g] for g in U])



    topology_eigen.append(eigenvector_centrality)  # 2

    return topology_eigen

def pearson_coor(input, target, epsilon=1e-7):
    vx = input - torch.mean(input, dim=(1, 2))[:, None, None]
    vy = target - torch.mean(target, dim=(1, 2))[:, None, None]
    cost = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)+epsilon) * torch.sqrt(torch.sum(vy ** 2)+epsilon)+epsilon)
    return cost

def GT_loss(target, predicted):

    # l1_loss
    l1_loss = torch.nn.L1Loss()
    # loss_pix2pix = l1_loss(target, predicted)

    # topological_loss
    target_n = target.detach().cpu().clone().numpy()
    predicted_n = predicted.detach().cpu().clone().numpy()
    torch.cuda.empty_cache()

    topo_loss = []
    

    for i in range(len(target_n)):

        cur_target = target_n[i]
        cur_predicted = predicted_n[i]

        target_t = eigen_centrality(cur_target)
        real_topology = torch.tensor(target_t[0])
        predicted_t = eigen_centrality(cur_predicted)
        fake_topology = torch.tensor(predicted_t[0])
        topo_loss.append(l1_loss(real_topology, fake_topology))

    topo_loss = torch.sum(torch.stack(topo_loss))

    pc_loss = pearson_coor(target, predicted).to(DEVICE)
    torch.cuda.empty_cache()

    # G_loss = loss_pix2pix + (1 - pc_loss) + topo_loss
    G_loss = (1 - pc_loss) + topo_loss


    return G_loss

In [8]:
def loss_calc(adj_ls, opp_adj_ls):
    total_loss = torch.Tensor([0]).to(DEVICE)
    mse_loss_fn = torch.nn.MSELoss()
    n = len(adj_ls)

    mse_loss = torch.Tensor([0]).to(DEVICE)
    for i, (adj, opp_adj) in enumerate(zip(adj_ls[::-1], opp_adj_ls)):
        weight = 2 * (i + 1) / (n * (1 + n))
        mse_loss = mse_loss + mse_loss_fn(adj, opp_adj) * weight

    gt_loss = torch.Tensor([0]).to(DEVICE)
    for i, (adj, opp_adj) in enumerate(zip(adj_ls[::-1], opp_adj_ls)):

        ### NOTE TEMPORARY MEASURE BECAUSE THEY TAKE IN (BATCHSIZE, xx, xx) shape ####
        temp_adj = adj.reshape(1, *adj.shape)
        temp_opp_adj = opp_adj.reshape(1, *opp_adj.shape)
        ##########################################################
        gt_loss = gt_loss + GT_loss(temp_adj, temp_opp_adj)

    gt_loss = gt_loss / n
        
    return mse_loss + gt_loss

# Training

In [8]:
from data_preparation import load_data_tensor

lr_train, lr_test, hr_train = load_data_tensor("dgl-icl")

In [11]:
lr_X_dim1 = torch.load('model_autoencoder/encode_lr_1.pt')
lr_X_dim2 = torch.load('model_autoencoder/encode_lr_2.pt')
hr_X_dim1 = torch.load('model_autoencoder/encode_hr_1.pt')
hr_X_dim2 = torch.load('model_autoencoder/encode_hr_2.pt')


lr_X_all = torch.empty((167, 320, 32))
for i in range(len(lr_X_dim1)):
    a, b = lr_X_dim1[i], lr_X_dim2[i]
    lr_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

hr_X_all = torch.empty((167, 536, 32))
for i in range(len(hr_X_dim1)):
    a, b = hr_X_dim1[i], hr_X_dim2[i]
    hr_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

In [12]:
trainloader = DataLoader(list(zip(lr_X_all, lr_train, hr_X_all, hr_train)), shuffle=True, batch_size=8)


up_changer = AdjacencyChangerUp(d=2,f_in=32).to(DEVICE)
down_changer = AdjacencyChangerDown(d=2,f_in=32).to(DEVICE)

up_optimizer = torch.optim.AdamW(up_changer.parameters(), lr=0.001, betas=(0.5, 0.999))
down_optimizer = torch.optim.AdamW(down_changer.parameters(), lr=0.001, betas=(0.5, 0.999))

sum(p.numel() for model in [up_changer, down_changer] for p in model.parameters())


2792776

In [13]:
def train(epochs, up_changer, down_changer, trainloader, up_optimizer, down_optimizer):

    up_changer.train()
    down_changer.train()
    for epoch in range(epochs):

        up_losses = []
        down_losses = []

        for X_lr, adj_lr, X_hr, adj_hr in tqdm(trainloader):

            freeze_model(up_changer)
            unfreeze_model(down_changer)
        
            down_optimizer.zero_grad()
            up_optimizer.zero_grad()

            down_batch_loss = []

            for i in range(len(X_lr)):

                up_adj_ls = up_changer(X_lr[i].to(DEVICE), adj_lr[i].to(DEVICE))
                torch.cuda.empty_cache()
                down_adj_ls = down_changer(X_hr[i].to(DEVICE), adj_hr[i].to(DEVICE))
                torch.cuda.empty_cache()

                down_batch_loss.append(loss_calc(down_adj_ls, up_adj_ls))

            down_loss = torch.mean(torch.stack(down_batch_loss))
            down_loss.backward()
            down_optimizer.step()

            print(f'Down Loss = {down_loss.detach().item()}')
            down_losses.append(down_loss.detach().item())
            del down_loss
            del down_batch_loss
            torch.cuda.empty_cache()

            unfreeze_model(up_changer)
            freeze_model(down_changer)
        
            down_optimizer.zero_grad()
            up_optimizer.zero_grad()

            up_batch_loss = []


            for i in range(len(X_lr)):

                up_adj_ls = up_changer(X_lr[i].to(DEVICE), adj_lr[i].to(DEVICE))
                torch.cuda.empty_cache()
                down_adj_ls = down_changer(X_hr[i].to(DEVICE), adj_hr[i].to(DEVICE))
                torch.cuda.empty_cache()
            
                up_batch_loss.append(loss_calc(up_adj_ls, down_adj_ls))

            up_loss = torch.mean(torch.stack(up_batch_loss))
            up_loss.backward()
            up_optimizer.step()

            print(f'Up Loss = {up_loss.detach().item()}')
            up_losses.append(up_loss.detach().item())
            del up_loss
            del up_batch_loss
            torch.cuda.empty_cache()



        epoch_up_loss = np.mean(up_losses)
        epoch_down_loss = np.mean(down_losses)

        print(f'epoch {epoch}: down loss = {epoch_down_loss}, up loss = {epoch_up_loss}')

    return up_changer, down_changer


In [14]:
from evaluation_fn import evaluate_predictions


def validation(up_changer, val_X_lr, val_adj_lr, val_adj_hr):
    print('begin validation')
    up_changer.eval()

    all_predictions = torch.empty((len(val_X_lr), N_HR_NODES, N_HR_NODES), requires_grad=False).cpu()

    for i in tqdm(range(len(val_X_lr))):
        all_predictions[i] = up_changer(val_X_lr[i].to(DEVICE), val_adj_lr[i].to(DEVICE))[-1].detach()
        torch.cuda.empty_cache()

    return evaluate_predictions(all_predictions, val_adj_hr)


In [15]:
def cross_validate(model, epochs, n_fold, X_lr, adj_lr, X_hr, adj_hr, d=2, f=32):
    kf = KFold(n_fold, shuffle=True, random_state=99)
    runs_results = []
    for train_idx, val_idx in kf.split(X_lr):
        train_X_lr, val_X_lr = X_lr[train_idx], X_lr[val_idx]
        train_adj_lr, val_adj_lr = adj_lr[train_idx], adj_lr[val_idx]
        train_X_hr = X_hr[train_idx]
        train_adj_hr, val_adj_hr = adj_hr[train_idx], adj_hr[val_idx]

        trainloader = DataLoader(list(zip(train_X_lr, train_adj_lr, train_X_hr, train_adj_hr)), shuffle=True, batch_size=8)

        up_changer = AdjacencyChangerUp(d=d,f_in=f).to(DEVICE)
        down_changer = AdjacencyChangerDown(d=d,f_in=f).to(DEVICE)

        up_optimizer = torch.optim.AdamW(up_changer.parameters(), lr=0.001, betas=(0.5, 0.999))
        down_optimizer = torch.optim.AdamW(down_changer.parameters(), lr=0.001, betas=(0.5, 0.999))    
        
        up_changer, down_changer = train(epochs, up_changer, down_changer, trainloader, up_optimizer, down_optimizer)
        model.append(up_changer)
        val_metrics = validation(up_changer, val_X_lr, val_adj_lr, val_adj_hr)
        runs_results.append(val_metrics)

    return runs_results


In [None]:
models=[]
cv_results = cross_validate(models, 10, 3, lr_X_all, lr_train, hr_X_all, hr_train)

In [16]:
# train on entire dataset
up_changer, down_changer = train(20, up_changer, down_changer, trainloader, up_optimizer, down_optimizer)

  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 2.2756166458129883


  5%|▍         | 1/21 [00:24<08:11, 24.59s/it]

Up Loss = 2.226686716079712
Down Loss = 2.2351632118225098


 10%|▉         | 2/21 [00:48<07:42, 24.34s/it]

Up Loss = 2.1783528327941895
Down Loss = 2.12856388092041


 14%|█▍        | 3/21 [01:13<07:23, 24.64s/it]

Up Loss = 2.057070732116699
Down Loss = 2.0111918449401855


 19%|█▉        | 4/21 [01:41<07:20, 25.93s/it]

Up Loss = 1.9125494956970215
Down Loss = 1.8966878652572632


 24%|██▍       | 5/21 [02:11<07:16, 27.25s/it]

Up Loss = 1.8017646074295044
Down Loss = 1.6918470859527588


 29%|██▊       | 6/21 [02:41<07:03, 28.21s/it]

Up Loss = 1.6624653339385986
Down Loss = 1.7208101749420166


 33%|███▎      | 7/21 [03:12<06:47, 29.08s/it]

Up Loss = 1.6564383506774902
Down Loss = 1.6837055683135986


 38%|███▊      | 8/21 [03:45<06:35, 30.46s/it]

Up Loss = 1.6499943733215332
Down Loss = 1.619810938835144


 43%|████▎     | 9/21 [04:21<06:27, 32.30s/it]

Up Loss = 1.5877271890640259
Down Loss = 1.6386438608169556


 48%|████▊     | 10/21 [04:59<06:12, 33.86s/it]

Up Loss = 1.616821050643921
Down Loss = 1.6018133163452148


 52%|█████▏    | 11/21 [05:36<05:49, 34.92s/it]

Up Loss = 1.5823755264282227
Down Loss = 1.5638483762741089


 57%|█████▋    | 12/21 [06:14<05:21, 35.78s/it]

Up Loss = 1.552931308746338
Down Loss = 1.550428867340088


 62%|██████▏   | 13/21 [06:51<04:50, 36.26s/it]

Up Loss = 1.5387907028198242
Down Loss = 1.5602506399154663


 67%|██████▋   | 14/21 [07:29<04:17, 36.78s/it]

Up Loss = 1.5354368686676025
Down Loss = 1.5590548515319824


 71%|███████▏  | 15/21 [08:07<03:41, 36.99s/it]

Up Loss = 1.5477080345153809
Down Loss = 1.564673900604248


 76%|███████▌  | 16/21 [08:44<03:05, 37.16s/it]

Up Loss = 1.5540921688079834
Down Loss = 1.5613884925842285


 81%|████████  | 17/21 [09:22<02:29, 37.27s/it]

Up Loss = 1.5441441535949707
Down Loss = 1.5463401079177856


 86%|████████▌ | 18/21 [10:00<01:52, 37.48s/it]

Up Loss = 1.5355116128921509
Down Loss = 1.5551259517669678


 90%|█████████ | 19/21 [10:38<01:15, 37.72s/it]

Up Loss = 1.5268666744232178
Down Loss = 1.5521488189697266


 95%|█████████▌| 20/21 [11:17<00:37, 37.95s/it]

Up Loss = 1.5209991931915283
Down Loss = 1.5471069812774658


100%|██████████| 21/21 [11:49<00:00, 33.80s/it]


Up Loss = 1.539756417274475
epoch 0: down loss = 1.717343875340053, up loss = 1.6823087306249709


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.5225045680999756


  5%|▍         | 1/21 [00:37<12:39, 37.99s/it]

Up Loss = 1.523155689239502
Down Loss = 1.5475094318389893


 10%|▉         | 2/21 [01:16<12:07, 38.27s/it]

Up Loss = 1.545259952545166
Down Loss = 1.5504640340805054


 14%|█▍        | 3/21 [01:55<11:34, 38.57s/it]

Up Loss = 1.5241985321044922
Down Loss = 1.5309808254241943


 19%|█▉        | 4/21 [02:33<10:51, 38.30s/it]

Up Loss = 1.5181193351745605
Down Loss = 1.5186502933502197


 24%|██▍       | 5/21 [03:11<10:14, 38.43s/it]

Up Loss = 1.5059785842895508
Down Loss = 1.5577337741851807


 29%|██▊       | 6/21 [03:49<09:34, 38.30s/it]

Up Loss = 1.5453550815582275
Down Loss = 1.536655306816101


 33%|███▎      | 7/21 [04:28<08:57, 38.39s/it]

Up Loss = 1.5185999870300293
Down Loss = 1.5202295780181885


 38%|███▊      | 8/21 [05:07<08:21, 38.56s/it]

Up Loss = 1.5104620456695557
Down Loss = 1.5240445137023926


 43%|████▎     | 9/21 [05:46<07:45, 38.78s/it]

Up Loss = 1.506600260734558
Down Loss = 1.537268877029419


 48%|████▊     | 10/21 [06:25<07:07, 38.83s/it]

Up Loss = 1.5163915157318115
Down Loss = 1.5319091081619263


 52%|█████▏    | 11/21 [07:04<06:29, 38.92s/it]

Up Loss = 1.5174965858459473
Down Loss = 1.602562427520752


 57%|█████▋    | 12/21 [07:44<05:51, 39.05s/it]

Up Loss = 1.5835403203964233
Down Loss = 1.5263160467147827


 62%|██████▏   | 13/21 [08:23<05:13, 39.17s/it]

Up Loss = 1.5066897869110107
Down Loss = 1.5479578971862793


 67%|██████▋   | 14/21 [09:03<04:35, 39.31s/it]

Up Loss = 1.5329943895339966
Down Loss = 1.5068058967590332


 71%|███████▏  | 15/21 [09:43<03:56, 39.48s/it]

Up Loss = 1.4897217750549316
Down Loss = 1.5190694332122803


 76%|███████▌  | 16/21 [10:22<03:16, 39.32s/it]

Up Loss = 1.4967507123947144
Down Loss = 1.5335288047790527


 81%|████████  | 17/21 [11:02<02:38, 39.53s/it]

Up Loss = 1.4937047958374023
Down Loss = 1.540257215499878


 86%|████████▌ | 18/21 [11:41<01:58, 39.53s/it]

Up Loss = 1.5130102634429932
Down Loss = 1.499385118484497


 90%|█████████ | 19/21 [12:21<01:18, 39.50s/it]

Up Loss = 1.4769997596740723
Down Loss = 1.5095932483673096


 95%|█████████▌| 20/21 [13:00<00:39, 39.50s/it]

Up Loss = 1.4893782138824463
Down Loss = 1.5131152868270874


100%|██████████| 21/21 [13:35<00:00, 38.83s/it]


Up Loss = 1.4762184619903564
epoch 1: down loss = 1.5322162707646687, up loss = 1.5138393356686546


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.4856350421905518


  5%|▍         | 1/21 [00:38<12:59, 38.96s/it]

Up Loss = 1.464379072189331
Down Loss = 1.501955270767212


 10%|▉         | 2/21 [01:17<12:12, 38.57s/it]

Up Loss = 1.4809925556182861
Down Loss = 1.4831898212432861


 14%|█▍        | 3/21 [01:55<11:34, 38.59s/it]

Up Loss = 1.4602960348129272
Down Loss = 1.510953426361084


 19%|█▉        | 4/21 [02:34<10:56, 38.59s/it]

Up Loss = 1.4890631437301636
Down Loss = 1.4653671979904175


 24%|██▍       | 5/21 [03:13<10:17, 38.60s/it]

Up Loss = 1.4418442249298096
Down Loss = 1.4817166328430176


 29%|██▊       | 6/21 [03:51<09:38, 38.56s/it]

Up Loss = 1.4596238136291504
Down Loss = 1.5027470588684082


 33%|███▎      | 7/21 [04:30<08:59, 38.54s/it]

Up Loss = 1.4875948429107666
Down Loss = 1.4823949337005615


 38%|███▊      | 8/21 [05:09<08:22, 38.68s/it]

Up Loss = 1.4622509479522705
Down Loss = 1.4981945753097534


 43%|████▎     | 9/21 [05:47<07:42, 38.58s/it]

Up Loss = 1.4732885360717773
Down Loss = 1.4632277488708496


 48%|████▊     | 10/21 [06:26<07:05, 38.67s/it]

Up Loss = 1.4466288089752197
Down Loss = 1.488065481185913


 52%|█████▏    | 11/21 [07:05<06:27, 38.71s/it]

Up Loss = 1.4608824253082275
Down Loss = 1.505277156829834


 57%|█████▋    | 12/21 [07:44<05:49, 38.79s/it]

Up Loss = 1.4850304126739502
Down Loss = 1.4738582372665405


 62%|██████▏   | 13/21 [08:22<05:10, 38.75s/it]

Up Loss = 1.4510889053344727
Down Loss = 1.4536163806915283


 67%|██████▋   | 14/21 [09:01<04:31, 38.80s/it]

Up Loss = 1.4215033054351807
Down Loss = 1.477665901184082


 71%|███████▏  | 15/21 [09:40<03:53, 38.92s/it]

Up Loss = 1.4546921253204346
Down Loss = 1.5019264221191406


 76%|███████▌  | 16/21 [10:19<03:14, 38.89s/it]

Up Loss = 1.4809917211532593
Down Loss = 1.474031925201416


 81%|████████  | 17/21 [10:58<02:35, 38.85s/it]

Up Loss = 1.4427125453948975
Down Loss = 1.4773342609405518


 86%|████████▌ | 18/21 [11:37<01:56, 38.91s/it]

Up Loss = 1.4499300718307495
Down Loss = 1.4851465225219727


 90%|█████████ | 19/21 [12:16<01:18, 39.02s/it]

Up Loss = 1.466301441192627
Down Loss = 1.4642947912216187


 95%|█████████▌| 20/21 [12:55<00:38, 38.99s/it]

Up Loss = 1.4416306018829346
Down Loss = 1.478352427482605


100%|██████████| 21/21 [13:29<00:00, 38.55s/it]


Up Loss = 1.4556647539138794
epoch 2: down loss = 1.4835691054662068, up loss = 1.460780490012396


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.4630050659179688


  5%|▍         | 1/21 [00:38<12:57, 38.85s/it]

Up Loss = 1.4501430988311768
Down Loss = 1.4551904201507568


 10%|▉         | 2/21 [01:17<12:16, 38.76s/it]

Up Loss = 1.4350614547729492
Down Loss = 1.4240357875823975


 14%|█▍        | 3/21 [01:56<11:41, 38.95s/it]

Up Loss = 1.3901616334915161
Down Loss = 1.4283757209777832


 19%|█▉        | 4/21 [02:36<11:04, 39.09s/it]

Up Loss = 1.407252311706543
Down Loss = 1.4000236988067627


 24%|██▍       | 5/21 [03:15<10:25, 39.08s/it]

Up Loss = 1.3622934818267822
Down Loss = 1.44087553024292


 29%|██▊       | 6/21 [03:53<09:45, 39.01s/it]

Up Loss = 1.412463665008545
Down Loss = 1.4511604309082031


 33%|███▎      | 7/21 [04:32<09:05, 38.99s/it]

Up Loss = 1.4251774549484253
Down Loss = 1.4113422632217407


 38%|███▊      | 8/21 [05:11<08:26, 38.93s/it]

Up Loss = 1.389689564704895
Down Loss = 1.4827102422714233


 43%|████▎     | 9/21 [05:50<07:46, 38.89s/it]

Up Loss = 1.459458351135254
Down Loss = 1.4161314964294434


 48%|████▊     | 10/21 [06:29<07:08, 38.96s/it]

Up Loss = 1.3919994831085205
Down Loss = 1.4287827014923096


 52%|█████▏    | 11/21 [07:09<06:31, 39.13s/it]

Up Loss = 1.4083565473556519
Down Loss = 1.4059820175170898


 57%|█████▋    | 12/21 [07:48<05:52, 39.18s/it]

Up Loss = 1.3774757385253906
Down Loss = 1.4086244106292725


 62%|██████▏   | 13/21 [08:27<05:12, 39.09s/it]

Up Loss = 1.3840410709381104
Down Loss = 1.4117566347122192


 67%|██████▋   | 14/21 [09:06<04:34, 39.15s/it]

Up Loss = 1.3844674825668335
Down Loss = 1.4105322360992432


 71%|███████▏  | 15/21 [09:46<03:55, 39.22s/it]

Up Loss = 1.392310619354248
Down Loss = 1.3979603052139282


 76%|███████▌  | 16/21 [10:26<03:17, 39.54s/it]

Up Loss = 1.3692049980163574
Down Loss = 1.435012698173523


 81%|████████  | 17/21 [11:06<02:39, 39.85s/it]

Up Loss = 1.4148883819580078
Down Loss = 1.4104048013687134


 86%|████████▌ | 18/21 [11:47<01:59, 39.95s/it]

Up Loss = 1.3862988948822021
Down Loss = 1.4210386276245117


 90%|█████████ | 19/21 [12:27<01:20, 40.09s/it]

Up Loss = 1.402000904083252
Down Loss = 1.4202158451080322


 95%|█████████▌| 20/21 [13:07<00:40, 40.06s/it]

Up Loss = 1.400352120399475
Down Loss = 1.3974131345748901


100%|██████████| 21/21 [13:42<00:00, 39.18s/it]


Up Loss = 1.3791779279708862
epoch 3: down loss = 1.424789241382054, up loss = 1.4010607231230963


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.4071762561798096


  5%|▍         | 1/21 [00:40<13:32, 40.63s/it]

Up Loss = 1.39235520362854
Down Loss = 1.361856460571289


 10%|▉         | 2/21 [01:20<12:42, 40.13s/it]

Up Loss = 1.3431087732315063
Down Loss = 1.3605992794036865


 14%|█▍        | 3/21 [02:00<12:03, 40.20s/it]

Up Loss = 1.3394349813461304
Down Loss = 1.3437684774398804


 19%|█▉        | 4/21 [02:40<11:23, 40.23s/it]

Up Loss = 1.322057843208313
Down Loss = 1.3481465578079224


 24%|██▍       | 5/21 [03:21<10:46, 40.41s/it]

Up Loss = 1.3199152946472168
Down Loss = 1.36899995803833


 29%|██▊       | 6/21 [04:01<10:02, 40.18s/it]

Up Loss = 1.3399372100830078
Down Loss = 1.351813793182373


 33%|███▎      | 7/21 [04:42<09:24, 40.31s/it]

Up Loss = 1.33025062084198
Down Loss = 1.3795541524887085


 38%|███▊      | 8/21 [05:22<08:42, 40.23s/it]

Up Loss = 1.3587422370910645
Down Loss = 1.3450219631195068


 43%|████▎     | 9/21 [06:02<08:03, 40.32s/it]

Up Loss = 1.325923204421997
Down Loss = 1.3608369827270508


 48%|████▊     | 10/21 [06:41<07:18, 39.88s/it]

Up Loss = 1.341756820678711
Down Loss = 1.386026382446289


 52%|█████▏    | 11/21 [07:21<06:38, 39.80s/it]

Up Loss = 1.3626958131790161
Down Loss = 1.3678011894226074


 57%|█████▋    | 12/21 [08:00<05:56, 39.62s/it]

Up Loss = 1.336987018585205
Down Loss = 1.350193977355957


 62%|██████▏   | 13/21 [08:39<05:16, 39.52s/it]

Up Loss = 1.3216636180877686
Down Loss = 1.3454973697662354


 67%|██████▋   | 14/21 [09:18<04:35, 39.36s/it]

Up Loss = 1.3202767372131348
Down Loss = 1.3423197269439697


 71%|███████▏  | 15/21 [09:57<03:55, 39.21s/it]

Up Loss = 1.3203394412994385
Down Loss = 1.3706612586975098


 76%|███████▌  | 16/21 [10:36<03:15, 39.04s/it]

Up Loss = 1.350781798362732
Down Loss = 1.3652044534683228


 81%|████████  | 17/21 [11:15<02:36, 39.04s/it]

Up Loss = 1.3481554985046387
Down Loss = 1.3114023208618164


 86%|████████▌ | 18/21 [11:54<01:57, 39.21s/it]

Up Loss = 1.2902119159698486
Down Loss = 1.3528153896331787


 90%|█████████ | 19/21 [12:34<01:18, 39.24s/it]

Up Loss = 1.3264663219451904
Down Loss = 1.3770575523376465


 95%|█████████▌| 20/21 [13:13<00:39, 39.23s/it]

Up Loss = 1.3570351600646973
Down Loss = 1.3471766710281372


100%|██████████| 21/21 [13:47<00:00, 39.41s/it]


Up Loss = 1.32389497756958
epoch 4: down loss = 1.3592347701390584, up loss = 1.3367614519028437


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.3159964084625244


  5%|▍         | 1/21 [00:39<13:07, 39.38s/it]

Up Loss = 1.2926870584487915
Down Loss = 1.2972197532653809


 10%|▉         | 2/21 [01:18<12:29, 39.46s/it]

Up Loss = 1.2747862339019775
Down Loss = 1.3142037391662598


 14%|█▍        | 3/21 [01:57<11:47, 39.29s/it]

Up Loss = 1.2868120670318604
Down Loss = 1.3238720893859863


 19%|█▉        | 4/21 [02:36<11:05, 39.14s/it]

Up Loss = 1.314470648765564
Down Loss = 1.3184826374053955


 24%|██▍       | 5/21 [03:15<10:25, 39.07s/it]

Up Loss = 1.3032987117767334
Down Loss = 1.3239918947219849


 29%|██▊       | 6/21 [03:55<09:46, 39.12s/it]

Up Loss = 1.2977969646453857
Down Loss = 1.3122233152389526


 33%|███▎      | 7/21 [04:34<09:08, 39.20s/it]

Up Loss = 1.2915563583374023
Down Loss = 1.3138405084609985


 38%|███▊      | 8/21 [05:13<08:29, 39.16s/it]

Up Loss = 1.2911951541900635
Down Loss = 1.2525358200073242


 43%|████▎     | 9/21 [05:52<07:50, 39.21s/it]

Up Loss = 1.2321455478668213
Down Loss = 1.2690129280090332


 48%|████▊     | 10/21 [06:32<07:12, 39.31s/it]

Up Loss = 1.2473047971725464
Down Loss = 1.2841148376464844


 52%|█████▏    | 11/21 [07:11<06:33, 39.36s/it]

Up Loss = 1.2655360698699951
Down Loss = 1.2588200569152832


 57%|█████▋    | 12/21 [07:51<05:53, 39.30s/it]

Up Loss = 1.242206335067749
Down Loss = 1.2721809148788452


 62%|██████▏   | 13/21 [08:30<05:13, 39.22s/it]

Up Loss = 1.2542961835861206
Down Loss = 1.2699618339538574


 67%|██████▋   | 14/21 [09:09<04:34, 39.16s/it]

Up Loss = 1.2552943229675293
Down Loss = 1.3390097618103027


 71%|███████▏  | 15/21 [09:48<03:54, 39.13s/it]

Up Loss = 1.3168116807937622
Down Loss = 1.3000621795654297


 76%|███████▌  | 16/21 [10:27<03:15, 39.12s/it]

Up Loss = 1.2841017246246338
Down Loss = 1.2903056144714355


 81%|████████  | 17/21 [11:06<02:36, 39.06s/it]

Up Loss = 1.2735627889633179
Down Loss = 1.2990803718566895


 86%|████████▌ | 18/21 [11:45<01:57, 39.07s/it]

Up Loss = 1.2845304012298584
Down Loss = 1.3023076057434082


 90%|█████████ | 19/21 [12:24<01:18, 39.10s/it]

Up Loss = 1.2884892225265503
Down Loss = 1.272789478302002


 95%|█████████▌| 20/21 [13:04<00:39, 39.25s/it]

Up Loss = 1.2525112628936768
Down Loss = 1.293738842010498


100%|██████████| 21/21 [13:38<00:00, 38.97s/it]


Up Loss = 1.2774814367294312
epoch 5: down loss = 1.2963690757751465, up loss = 1.2774702367328463


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.2314915657043457


  5%|▍         | 1/21 [00:39<13:09, 39.45s/it]

Up Loss = 1.2165501117706299
Down Loss = 1.266463279724121


 10%|▉         | 2/21 [01:18<12:29, 39.47s/it]

Up Loss = 1.251663327217102
Down Loss = 1.2514666318893433


 14%|█▍        | 3/21 [01:58<11:50, 39.49s/it]

Up Loss = 1.2322945594787598
Down Loss = 1.2701244354248047


 19%|█▉        | 4/21 [02:38<11:12, 39.57s/it]

Up Loss = 1.257623553276062
Down Loss = 1.2863240242004395


 24%|██▍       | 5/21 [03:17<10:32, 39.55s/it]

Up Loss = 1.2690765857696533
Down Loss = 1.2366044521331787


 29%|██▊       | 6/21 [03:57<09:55, 39.70s/it]

Up Loss = 1.2194288969039917
Down Loss = 1.2761046886444092


 33%|███▎      | 7/21 [04:37<09:18, 39.88s/it]

Up Loss = 1.2635223865509033
Down Loss = 1.2198874950408936


 38%|███▊      | 8/21 [05:18<08:40, 40.04s/it]

Up Loss = 1.1997511386871338
Down Loss = 1.244458794593811


 43%|████▎     | 9/21 [05:58<08:01, 40.11s/it]

Up Loss = 1.226531982421875
Down Loss = 1.2734547853469849


 48%|████▊     | 10/21 [06:39<07:23, 40.29s/it]

Up Loss = 1.2546851634979248
Down Loss = 1.252144694328308


 52%|█████▏    | 11/21 [07:19<06:42, 40.23s/it]

Up Loss = 1.2313258647918701
Down Loss = 1.2572250366210938


 57%|█████▋    | 12/21 [07:59<06:02, 40.27s/it]

Up Loss = 1.2360882759094238
Down Loss = 1.241638422012329


 62%|██████▏   | 13/21 [08:39<05:21, 40.16s/it]

Up Loss = 1.2165967226028442
Down Loss = 1.225268840789795


 67%|██████▋   | 14/21 [09:20<04:41, 40.25s/it]

Up Loss = 1.202242374420166
Down Loss = 1.238571286201477


 71%|███████▏  | 15/21 [10:00<04:01, 40.23s/it]

Up Loss = 1.2035739421844482
Down Loss = 1.2218091487884521


 76%|███████▌  | 16/21 [10:40<03:21, 40.25s/it]

Up Loss = 1.199913740158081
Down Loss = 1.267441749572754


 81%|████████  | 17/21 [11:20<02:40, 40.14s/it]

Up Loss = 1.239661693572998
Down Loss = 1.23179030418396


 86%|████████▌ | 18/21 [11:59<01:59, 39.95s/it]

Up Loss = 1.2145063877105713
Down Loss = 1.2323309183120728


 90%|█████████ | 19/21 [12:39<01:19, 39.71s/it]

Up Loss = 1.2068959474563599
Down Loss = 1.230269432067871


 95%|█████████▌| 20/21 [13:18<00:39, 39.52s/it]

Up Loss = 1.221089482307434
Down Loss = 1.246753454208374


100%|██████████| 21/21 [13:52<00:00, 39.65s/it]


Up Loss = 1.2348134517669678
epoch 6: down loss = 1.247696354275658, up loss = 1.2284683613550096


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.173277497291565


  5%|▍         | 1/21 [00:38<12:59, 38.98s/it]

Up Loss = 1.1731535196304321
Down Loss = 1.2013449668884277


 10%|▉         | 2/21 [01:18<12:25, 39.24s/it]

Up Loss = 1.216294765472412
Down Loss = 1.234299659729004


 14%|█▍        | 3/21 [01:57<11:45, 39.21s/it]

Up Loss = 1.2418338060379028
Down Loss = 1.215313196182251


 19%|█▉        | 4/21 [02:36<11:04, 39.07s/it]

Up Loss = 1.3007688522338867
Down Loss = 1.2142307758331299


 24%|██▍       | 5/21 [03:15<10:25, 39.09s/it]

Up Loss = 1.2562533617019653
Down Loss = 1.2299058437347412


 29%|██▊       | 6/21 [03:54<09:46, 39.09s/it]

Up Loss = 1.2052199840545654
Down Loss = 1.1961467266082764


 33%|███▎      | 7/21 [04:33<09:07, 39.10s/it]

Up Loss = 1.2077467441558838
Down Loss = 1.2081835269927979


 38%|███▊      | 8/21 [05:12<08:28, 39.11s/it]

Up Loss = 1.222405195236206
Down Loss = 1.219996690750122


 43%|████▎     | 9/21 [05:51<07:47, 38.97s/it]

Up Loss = 1.2638976573944092
Down Loss = 1.1924357414245605


 48%|████▊     | 10/21 [06:30<07:08, 38.97s/it]

Up Loss = 1.1955351829528809
Down Loss = 1.226625919342041


 52%|█████▏    | 11/21 [07:09<06:30, 39.01s/it]

Up Loss = 1.3245773315429688
Down Loss = 1.248091220855713


 57%|█████▋    | 12/21 [07:49<05:52, 39.19s/it]

Up Loss = 1.284676194190979
Down Loss = 1.229581356048584


 62%|██████▏   | 13/21 [08:28<05:13, 39.20s/it]

Up Loss = 1.2417155504226685
Down Loss = 1.202688217163086


 67%|██████▋   | 14/21 [09:07<04:34, 39.28s/it]

Up Loss = 1.2073469161987305
Down Loss = 1.2353609800338745


 71%|███████▏  | 15/21 [09:47<03:55, 39.30s/it]

Up Loss = 1.2869726419448853
Down Loss = 1.212803602218628


 76%|███████▌  | 16/21 [10:26<03:17, 39.41s/it]

Up Loss = 1.2479922771453857
Down Loss = 1.2143086194992065


 81%|████████  | 17/21 [11:06<02:37, 39.31s/it]

Up Loss = 1.2487962245941162
Down Loss = 1.2236363887786865


 86%|████████▌ | 18/21 [11:45<01:57, 39.26s/it]

Up Loss = 1.326598882675171
Down Loss = 1.2325628995895386


 90%|█████████ | 19/21 [12:23<01:18, 39.10s/it]

Up Loss = 1.2462924718856812
Down Loss = 1.2342960834503174


 95%|█████████▌| 20/21 [13:02<00:39, 39.03s/it]

Up Loss = 1.2803834676742554
Down Loss = 1.2468476295471191


100%|██████████| 21/21 [13:36<00:00, 38.90s/it]


Up Loss = 1.255540370941162
epoch 7: down loss = 1.2186636924743652, up loss = 1.2492381618136452


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.1898541450500488


  5%|▍         | 1/21 [00:39<13:11, 39.56s/it]

Up Loss = 1.2067822217941284
Down Loss = 1.2722256183624268


 10%|▉         | 2/21 [01:18<12:23, 39.16s/it]

Up Loss = 1.2540658712387085
Down Loss = 1.2226110696792603


 14%|█▍        | 3/21 [01:57<11:46, 39.24s/it]

Up Loss = 1.213937759399414
Down Loss = 1.1972975730895996


 19%|█▉        | 4/21 [02:37<11:08, 39.34s/it]

Up Loss = 1.185204267501831
Down Loss = 1.2225794792175293


 24%|██▍       | 5/21 [03:16<10:27, 39.23s/it]

Up Loss = 1.2133570909500122
Down Loss = 1.1826252937316895


 29%|██▊       | 6/21 [03:55<09:48, 39.20s/it]

Up Loss = 1.2016048431396484
Down Loss = 1.1832057237625122


 33%|███▎      | 7/21 [04:34<09:09, 39.23s/it]

Up Loss = 1.1863431930541992
Down Loss = 1.17429518699646


 38%|███▊      | 8/21 [05:14<08:30, 39.27s/it]

Up Loss = 1.1663068532943726
Down Loss = 1.1946334838867188


 43%|████▎     | 9/21 [05:53<07:51, 39.29s/it]

Up Loss = 1.1625512838363647
Down Loss = 1.1548843383789062


 48%|████▊     | 10/21 [06:33<07:13, 39.40s/it]

Up Loss = 1.167165756225586
Down Loss = 1.2200696468353271


 52%|█████▏    | 11/21 [07:12<06:34, 39.42s/it]

Up Loss = 1.1905003786087036
Down Loss = 1.2266156673431396


 57%|█████▋    | 12/21 [07:52<05:55, 39.47s/it]

Up Loss = 1.2499754428863525
Down Loss = 1.2018007040023804


 62%|██████▏   | 13/21 [08:32<05:17, 39.73s/it]

Up Loss = 1.1923949718475342
Down Loss = 1.165869951248169


 67%|██████▋   | 14/21 [09:13<04:40, 40.00s/it]

Up Loss = 1.2227551937103271
Down Loss = 1.1666855812072754


 71%|███████▏  | 15/21 [09:53<04:00, 40.13s/it]

Up Loss = 1.160707712173462
Down Loss = 1.1977789402008057


 76%|███████▌  | 16/21 [10:33<03:21, 40.20s/it]

Up Loss = 1.1950829029083252
Down Loss = 1.2278101444244385


 81%|████████  | 17/21 [11:14<02:40, 40.20s/it]

Up Loss = 1.2768144607543945
Down Loss = 1.145090937614441


 86%|████████▌ | 18/21 [11:54<02:00, 40.16s/it]

Up Loss = 1.151961326599121
Down Loss = 1.1574761867523193


 90%|█████████ | 19/21 [12:34<01:20, 40.36s/it]

Up Loss = 1.1830193996429443
Down Loss = 1.1831624507904053


 95%|█████████▌| 20/21 [13:15<00:40, 40.29s/it]

Up Loss = 1.1887036561965942
Down Loss = 1.188347578048706


100%|██████████| 21/21 [13:50<00:00, 39.54s/it]


Up Loss = 1.2200838327407837
epoch 8: down loss = 1.194043795267741, up loss = 1.1994913532620384


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.1239649057388306


  5%|▍         | 1/21 [00:40<13:31, 40.57s/it]

Up Loss = 1.1114068031311035
Down Loss = 1.1285814046859741


 10%|▉         | 2/21 [01:20<12:46, 40.35s/it]

Up Loss = 1.1335543394088745
Down Loss = 1.1426191329956055


 14%|█▍        | 3/21 [02:01<12:06, 40.35s/it]

Up Loss = 1.1474350690841675
Down Loss = 1.1404051780700684


 19%|█▉        | 4/21 [02:41<11:28, 40.48s/it]

Up Loss = 1.1439355611801147
Down Loss = 1.1559441089630127


 24%|██▍       | 5/21 [03:20<10:40, 40.00s/it]

Up Loss = 1.1493744850158691
Down Loss = 1.1492310762405396


 29%|██▊       | 6/21 [04:00<09:56, 39.77s/it]

Up Loss = 1.1562871932983398
Down Loss = 1.1276895999908447


 33%|███▎      | 7/21 [04:39<09:14, 39.59s/it]

Up Loss = 1.1261589527130127
Down Loss = 1.1389191150665283


 38%|███▊      | 8/21 [05:18<08:31, 39.33s/it]

Up Loss = 1.1330516338348389
Down Loss = 1.1266531944274902


 43%|████▎     | 9/21 [05:57<07:51, 39.26s/it]

Up Loss = 1.1237668991088867
Down Loss = 1.123106837272644


 48%|████▊     | 10/21 [06:36<07:10, 39.15s/it]

Up Loss = 1.1178672313690186
Down Loss = 1.1222487688064575


 52%|█████▏    | 11/21 [07:15<06:30, 39.06s/it]

Up Loss = 1.1094791889190674
Down Loss = 1.1283093690872192


 57%|█████▋    | 12/21 [07:54<05:51, 39.01s/it]

Up Loss = 1.1142430305480957
Down Loss = 1.1459243297576904


 62%|██████▏   | 13/21 [08:32<05:11, 38.98s/it]

Up Loss = 1.1295170783996582
Down Loss = 1.1411378383636475


 67%|██████▋   | 14/21 [09:12<04:33, 39.12s/it]

Up Loss = 1.1176173686981201
Down Loss = 1.1531102657318115


 71%|███████▏  | 15/21 [09:51<03:54, 39.15s/it]

Up Loss = 1.1240144968032837
Down Loss = 1.1752076148986816


 76%|███████▌  | 16/21 [10:30<03:15, 39.08s/it]

Up Loss = 1.1470580101013184
Down Loss = 1.1935148239135742


 81%|████████  | 17/21 [11:09<02:36, 39.07s/it]

Up Loss = 1.1264376640319824
Down Loss = 1.1513317823410034


 86%|████████▌ | 18/21 [11:48<01:57, 39.03s/it]

Up Loss = 1.1075340509414673
Down Loss = 1.1451865434646606


 90%|█████████ | 19/21 [12:27<01:17, 39.00s/it]

Up Loss = 1.1109466552734375
Down Loss = 1.1318778991699219


 95%|█████████▌| 20/21 [13:06<00:39, 39.09s/it]

Up Loss = 1.103123426437378
Down Loss = 1.1184825897216797


100%|██████████| 21/21 [13:41<00:00, 39.11s/it]


Up Loss = 1.0789158344268799
epoch 9: down loss = 1.141116494224185, up loss = 1.1243678558440435


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.0844552516937256


  5%|▍         | 1/21 [00:39<13:12, 39.63s/it]

Up Loss = 1.0592758655548096
Down Loss = 1.1316659450531006


 10%|▉         | 2/21 [01:18<12:26, 39.31s/it]

Up Loss = 1.0655643939971924
Down Loss = 1.178308367729187


 14%|█▍        | 3/21 [01:57<11:44, 39.16s/it]

Up Loss = 1.1303019523620605
Down Loss = 1.1266722679138184


 19%|█▉        | 4/21 [02:36<11:05, 39.16s/it]

Up Loss = 1.075299620628357
Down Loss = 1.1116411685943604


 24%|██▍       | 5/21 [03:15<10:25, 39.07s/it]

Up Loss = 1.06744384765625
Down Loss = 1.119348168373108


 29%|██▊       | 6/21 [03:54<09:46, 39.10s/it]

Up Loss = 1.0714502334594727
Down Loss = 1.1319916248321533


 33%|███▎      | 7/21 [04:34<09:08, 39.17s/it]

Up Loss = 1.0867223739624023
Down Loss = 1.1046545505523682


 38%|███▊      | 8/21 [05:13<08:28, 39.15s/it]

Up Loss = 1.0771266222000122
Down Loss = 1.2279958724975586


 43%|████▎     | 9/21 [05:52<07:51, 39.27s/it]

Up Loss = 1.1501812934875488
Down Loss = 1.1537387371063232


 48%|████▊     | 10/21 [06:31<07:09, 39.08s/it]

Up Loss = 1.0930819511413574
Down Loss = 1.111870527267456


 52%|█████▏    | 11/21 [07:10<06:31, 39.17s/it]

Up Loss = 1.062852382659912
Down Loss = 1.0828231573104858


 57%|█████▋    | 12/21 [07:49<05:51, 39.08s/it]

Up Loss = 1.0536096096038818
Down Loss = 1.1160852909088135


 62%|██████▏   | 13/21 [08:28<05:12, 39.05s/it]

Up Loss = 1.0765074491500854
Down Loss = 1.0727193355560303


 67%|██████▋   | 14/21 [09:07<04:33, 39.06s/it]

Up Loss = 1.043737769126892
Down Loss = 1.1256816387176514


 71%|███████▏  | 15/21 [09:46<03:54, 39.06s/it]

Up Loss = 1.0707558393478394
Down Loss = 1.1183123588562012


 76%|███████▌  | 16/21 [10:26<03:15, 39.10s/it]

Up Loss = 1.095773458480835
Down Loss = 1.1140731573104858


 81%|████████  | 17/21 [11:05<02:36, 39.17s/it]

Up Loss = 1.0710690021514893
Down Loss = 1.0910683870315552


 86%|████████▌ | 18/21 [11:45<01:58, 39.36s/it]

Up Loss = 1.0628304481506348
Down Loss = 1.0589466094970703


 90%|█████████ | 19/21 [12:25<01:19, 39.70s/it]

Up Loss = 1.0336803197860718
Down Loss = 1.0776350498199463


 95%|█████████▌| 20/21 [13:06<00:40, 40.02s/it]

Up Loss = 1.0541781187057495
Down Loss = 1.0849422216415405


100%|██████████| 21/21 [13:41<00:00, 39.13s/it]


Up Loss = 1.0614359378814697
epoch 10: down loss = 1.1154585565839494, up loss = 1.0744227852140154


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.0517644882202148


  5%|▍         | 1/21 [00:40<13:38, 40.94s/it]

Up Loss = 1.0369012355804443
Down Loss = 1.0372016429901123


 10%|▉         | 2/21 [01:21<12:53, 40.70s/it]

Up Loss = 1.0259488821029663
Down Loss = 1.0616880655288696


 14%|█▍        | 3/21 [02:01<12:09, 40.54s/it]

Up Loss = 1.0388174057006836
Down Loss = 1.0620955228805542


 19%|█▉        | 4/21 [02:41<11:26, 40.36s/it]

Up Loss = 1.0433942079544067
Down Loss = 1.0650670528411865


 24%|██▍       | 5/21 [03:22<10:48, 40.50s/it]

Up Loss = 1.0463266372680664
Down Loss = 1.0312464237213135


 29%|██▊       | 6/21 [04:02<10:06, 40.43s/it]

Up Loss = 1.01655912399292
Down Loss = 1.0528523921966553


 33%|███▎      | 7/21 [04:43<09:28, 40.59s/it]

Up Loss = 1.0374598503112793
Down Loss = 1.0683356523513794


 38%|███▊      | 8/21 [05:24<08:47, 40.55s/it]

Up Loss = 1.0492098331451416
Down Loss = 1.040966272354126


 43%|████▎     | 9/21 [06:05<08:07, 40.59s/it]

Up Loss = 1.0207674503326416
Down Loss = 1.1204077005386353


 48%|████▊     | 10/21 [06:45<07:25, 40.53s/it]

Up Loss = 1.0921223163604736
Down Loss = 1.0430080890655518


 52%|█████▏    | 11/21 [07:25<06:43, 40.32s/it]

Up Loss = 1.013181447982788
Down Loss = 1.0691676139831543


 57%|█████▋    | 12/21 [08:03<05:58, 39.83s/it]

Up Loss = 1.0444462299346924
Down Loss = 1.038240909576416


 62%|██████▏   | 13/21 [08:43<05:17, 39.63s/it]

Up Loss = 1.021756887435913
Down Loss = 1.0522334575653076


 67%|██████▋   | 14/21 [09:22<04:36, 39.46s/it]

Up Loss = 1.0261694192886353
Down Loss = 1.0431592464447021


 71%|███████▏  | 15/21 [10:00<03:55, 39.18s/it]

Up Loss = 1.0163075923919678
Down Loss = 1.0646960735321045


 76%|███████▌  | 16/21 [10:39<03:15, 39.06s/it]

Up Loss = 1.0358699560165405
Down Loss = 1.0777919292449951


 81%|████████  | 17/21 [11:18<02:36, 39.05s/it]

Up Loss = 1.0519945621490479
Down Loss = 1.038284182548523


 86%|████████▌ | 18/21 [11:57<01:57, 39.03s/it]

Up Loss = 1.019047737121582
Down Loss = 1.0254316329956055


 90%|█████████ | 19/21 [12:36<01:18, 39.00s/it]

Up Loss = 0.9875664710998535
Down Loss = 1.052702784538269


 95%|█████████▌| 20/21 [13:15<00:39, 39.10s/it]

Up Loss = 1.0260975360870361
Down Loss = 1.0683611631393433


100%|██████████| 21/21 [13:49<00:00, 39.50s/it]


Up Loss = 1.0319303274154663
epoch 11: down loss = 1.055462014107477, up loss = 1.0324702433177404


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.0712203979492188


  5%|▍         | 1/21 [00:38<12:52, 38.65s/it]

Up Loss = 1.0384669303894043
Down Loss = 1.045759677886963


 10%|▉         | 2/21 [01:17<12:21, 39.01s/it]

Up Loss = 0.9934582114219666
Down Loss = 1.006262183189392


 14%|█▍        | 3/21 [01:56<11:41, 38.96s/it]

Up Loss = 0.9782543182373047
Down Loss = 1.0890172719955444


 19%|█▉        | 4/21 [02:35<11:03, 39.05s/it]

Up Loss = 1.033366322517395
Down Loss = 1.0684864521026611


 24%|██▍       | 5/21 [03:15<10:24, 39.04s/it]

Up Loss = 1.0290651321411133
Down Loss = 1.0775666236877441


 29%|██▊       | 6/21 [03:54<09:46, 39.11s/it]

Up Loss = 1.008797287940979
Down Loss = 1.042752981185913


 33%|███▎      | 7/21 [04:33<09:08, 39.16s/it]

Up Loss = 0.9912079572677612
Down Loss = 1.0785582065582275


 38%|███▊      | 8/21 [05:12<08:28, 39.08s/it]

Up Loss = 1.022120714187622
Down Loss = 1.0564191341400146


 43%|████▎     | 9/21 [05:51<07:50, 39.17s/it]

Up Loss = 1.013680100440979
Down Loss = 1.038799524307251


 48%|████▊     | 10/21 [06:31<07:11, 39.24s/it]

Up Loss = 0.9996470808982849
Down Loss = 1.0443044900894165


 52%|█████▏    | 11/21 [07:10<06:32, 39.21s/it]

Up Loss = 1.0019803047180176
Down Loss = 1.037131905555725


 57%|█████▋    | 12/21 [07:49<05:53, 39.28s/it]

Up Loss = 0.9987239837646484
Down Loss = 1.0585296154022217


 62%|██████▏   | 13/21 [08:29<05:14, 39.26s/it]

Up Loss = 1.0171035528182983
Down Loss = 1.0455892086029053


 67%|██████▋   | 14/21 [09:08<04:35, 39.29s/it]

Up Loss = 0.9842837452888489
Down Loss = 1.0911707878112793


 71%|███████▏  | 15/21 [09:47<03:56, 39.36s/it]

Up Loss = 1.026277780532837
Down Loss = 1.1096479892730713


 76%|███████▌  | 16/21 [10:27<03:16, 39.32s/it]

Up Loss = 1.0094963312149048
Down Loss = 1.0937135219573975


 81%|████████  | 17/21 [11:06<02:37, 39.33s/it]

Up Loss = 1.0255420207977295
Down Loss = 1.0881590843200684


 86%|████████▌ | 18/21 [11:46<01:58, 39.39s/it]

Up Loss = 1.011293888092041
Down Loss = 1.0784469842910767


 90%|█████████ | 19/21 [12:25<01:18, 39.29s/it]

Up Loss = 1.0107680559158325
Down Loss = 1.0320810079574585


 95%|█████████▌| 20/21 [13:04<00:39, 39.39s/it]

Up Loss = 0.9786460399627686
Down Loss = 1.028914213180542


100%|██████████| 21/21 [13:38<00:00, 39.00s/it]


Up Loss = 0.9868826866149902
epoch 12: down loss = 1.0610729172116233, up loss = 1.0075744021506536


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.0247368812561035


  5%|▍         | 1/21 [00:38<12:56, 38.83s/it]

Up Loss = 0.9718309640884399
Down Loss = 1.0255767107009888


 10%|▉         | 2/21 [01:18<12:27, 39.32s/it]

Up Loss = 0.9560459852218628
Down Loss = 1.0944839715957642


 14%|█▍        | 3/21 [01:57<11:46, 39.25s/it]

Up Loss = 0.9905688166618347
Down Loss = 1.0878493785858154


 19%|█▉        | 4/21 [02:36<11:06, 39.22s/it]

Up Loss = 1.0128779411315918
Down Loss = 1.0620007514953613


 24%|██▍       | 5/21 [03:16<10:32, 39.54s/it]

Up Loss = 0.9740972518920898
Down Loss = 1.050309658050537


 29%|██▊       | 6/21 [03:56<09:54, 39.62s/it]

Up Loss = 0.9783550500869751
Down Loss = 1.0635828971862793


 33%|███▎      | 7/21 [04:36<09:17, 39.83s/it]

Up Loss = 0.9978643655776978
Down Loss = 1.0207045078277588


 38%|███▊      | 8/21 [05:17<08:40, 40.02s/it]

Up Loss = 0.9764441847801208
Down Loss = 1.051710844039917


 43%|████▎     | 9/21 [05:57<08:01, 40.10s/it]

Up Loss = 0.9902056455612183
Down Loss = 1.080944538116455


 48%|████▊     | 10/21 [06:38<07:22, 40.19s/it]

Up Loss = 0.9825273752212524
Down Loss = 1.0573322772979736


 52%|█████▏    | 11/21 [07:18<06:42, 40.21s/it]

Up Loss = 1.0146818161010742
Down Loss = 1.0647071599960327


 57%|█████▋    | 12/21 [07:59<06:03, 40.43s/it]

Up Loss = 0.9657896757125854
Down Loss = 1.0401883125305176


 62%|██████▏   | 13/21 [08:39<05:23, 40.47s/it]

Up Loss = 0.9745967388153076
Down Loss = 1.0641717910766602


 67%|██████▋   | 14/21 [09:20<04:43, 40.48s/it]

Up Loss = 0.9855077266693115
Down Loss = 1.0720359086990356


 71%|███████▏  | 15/21 [10:00<04:01, 40.33s/it]

Up Loss = 1.015534520149231
Down Loss = 1.036412000656128


 76%|███████▌  | 16/21 [10:40<03:21, 40.38s/it]

Up Loss = 0.9795832633972168
Down Loss = 1.04936945438385


 81%|████████  | 17/21 [11:21<02:41, 40.40s/it]

Up Loss = 0.9912271499633789
Down Loss = 1.0701823234558105


 86%|████████▌ | 18/21 [12:01<02:01, 40.36s/it]

Up Loss = 0.9659311771392822
Down Loss = 1.0734635591506958


 90%|█████████ | 19/21 [12:41<01:20, 40.23s/it]

Up Loss = 1.0071258544921875
Down Loss = 1.064132809638977


 95%|█████████▌| 20/21 [13:20<00:39, 39.90s/it]

Up Loss = 0.9812257289886475
Down Loss = 1.0461030006408691


100%|██████████| 21/21 [13:54<00:00, 39.75s/it]


Up Loss = 0.9798258543014526
epoch 13: down loss = 1.0571427969705491, up loss = 0.9853260517120361


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 1.0121456384658813


  5%|▍         | 1/21 [00:39<13:05, 39.29s/it]

Up Loss = 0.9547142386436462
Down Loss = 1.0288814306259155


 10%|▉         | 2/21 [01:18<12:29, 39.44s/it]

Up Loss = 0.9863457679748535
Down Loss = 1.0408167839050293


 14%|█▍        | 3/21 [01:58<11:52, 39.57s/it]

Up Loss = 0.9565832018852234
Down Loss = 1.027886152267456


 19%|█▉        | 4/21 [02:37<11:08, 39.35s/it]

Up Loss = 0.9643497467041016
Down Loss = 1.0653281211853027


 24%|██▍       | 5/21 [03:16<10:29, 39.37s/it]

Up Loss = 0.9610321521759033
Down Loss = 0.9928345680236816


 29%|██▊       | 6/21 [03:56<09:51, 39.47s/it]

Up Loss = 0.9515672922134399
Down Loss = 1.0746676921844482


 33%|███▎      | 7/21 [04:36<09:12, 39.44s/it]

Up Loss = 0.9607311487197876
Down Loss = 0.9925475120544434


 38%|███▊      | 8/21 [05:15<08:32, 39.44s/it]

Up Loss = 0.9509662389755249
Down Loss = 1.051243782043457


 43%|████▎     | 9/21 [05:54<07:51, 39.32s/it]

Up Loss = 0.9802550077438354
Down Loss = 0.996488630771637


 48%|████▊     | 10/21 [06:33<07:11, 39.20s/it]

Up Loss = 0.9538806080818176
Down Loss = 1.0492852926254272


 52%|█████▏    | 11/21 [07:12<06:30, 39.05s/it]

Up Loss = 0.9519320726394653
Down Loss = 1.0480482578277588


 57%|█████▋    | 12/21 [07:51<05:51, 39.07s/it]

Up Loss = 1.0038952827453613
Down Loss = 1.0168436765670776


 62%|██████▏   | 13/21 [08:30<05:12, 39.12s/it]

Up Loss = 0.9583865404129028
Down Loss = 1.020656943321228


 67%|██████▋   | 14/21 [09:09<04:33, 39.02s/it]

Up Loss = 0.9645981192588806
Down Loss = 1.002620816230774


 71%|███████▏  | 15/21 [09:47<03:53, 38.92s/it]

Up Loss = 0.936041533946991
Down Loss = 0.9912616610527039


 76%|███████▌  | 16/21 [10:27<03:15, 39.03s/it]

Up Loss = 0.9485357999801636
Down Loss = 1.0177726745605469


 81%|████████  | 17/21 [11:06<02:36, 39.02s/it]

Up Loss = 0.9499863386154175
Down Loss = 1.0474928617477417


 86%|████████▌ | 18/21 [11:45<01:56, 38.96s/it]

Up Loss = 1.0059176683425903
Down Loss = 1.0038245916366577


 90%|█████████ | 19/21 [12:23<01:17, 38.95s/it]

Up Loss = 0.9441556930541992
Down Loss = 0.9832793474197388


 95%|█████████▌| 20/21 [13:03<00:39, 39.00s/it]

Up Loss = 0.9505484700202942
Down Loss = 1.0068349838256836


100%|██████████| 21/21 [13:37<00:00, 38.91s/it]


Up Loss = 0.9609225988388062
epoch 14: down loss = 1.0224172103972662, up loss = 0.9616831200463432


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 0.9682633280754089


  5%|▍         | 1/21 [00:38<12:52, 38.60s/it]

Up Loss = 0.942138671875
Down Loss = 1.0029056072235107


 10%|▉         | 2/21 [01:17<12:16, 38.75s/it]

Up Loss = 0.9566919803619385
Down Loss = 0.9640412330627441


 14%|█▍        | 3/21 [01:56<11:42, 39.01s/it]

Up Loss = 0.9387770295143127
Down Loss = 0.9856676459312439


 19%|█▉        | 4/21 [02:36<11:06, 39.20s/it]

Up Loss = 0.9521693587303162
Down Loss = 0.9912139177322388


 24%|██▍       | 5/21 [03:15<10:24, 39.04s/it]

Up Loss = 0.9423769116401672
Down Loss = 0.982347846031189


 29%|██▊       | 6/21 [03:54<09:45, 39.05s/it]

Up Loss = 0.946466326713562
Down Loss = 0.9856443405151367


 33%|███▎      | 7/21 [04:33<09:06, 39.05s/it]

Up Loss = 0.9541985988616943
Down Loss = 0.9507447481155396


 38%|███▊      | 8/21 [05:11<08:26, 38.97s/it]

Up Loss = 0.9206249713897705
Down Loss = 0.9489458799362183


 43%|████▎     | 9/21 [05:50<07:47, 38.93s/it]

Up Loss = 0.9226548671722412
Down Loss = 0.9708945751190186


 48%|████▊     | 10/21 [06:30<07:09, 39.04s/it]

Up Loss = 0.9414607286453247
Down Loss = 0.9905869960784912


 52%|█████▏    | 11/21 [07:08<06:29, 38.96s/it]

Up Loss = 0.9626657962799072
Down Loss = 0.960228443145752


 57%|█████▋    | 12/21 [07:47<05:49, 38.87s/it]

Up Loss = 0.9361797571182251
Down Loss = 0.9495106935501099


 62%|██████▏   | 13/21 [08:27<05:12, 39.07s/it]

Up Loss = 0.9272522926330566
Down Loss = 0.9634774923324585


 67%|██████▋   | 14/21 [09:07<04:36, 39.45s/it]

Up Loss = 0.938002347946167
Down Loss = 0.9475343823432922


 71%|███████▏  | 15/21 [09:47<03:57, 39.65s/it]

Up Loss = 0.9212891459465027
Down Loss = 0.9481610059738159


 76%|███████▌  | 16/21 [10:27<03:18, 39.72s/it]

Up Loss = 0.9208165407180786
Down Loss = 0.9452780485153198


 81%|████████  | 17/21 [11:07<02:39, 39.95s/it]

Up Loss = 0.9231210350990295
Down Loss = 0.955782413482666


 86%|████████▌ | 18/21 [11:48<02:00, 40.03s/it]

Up Loss = 0.928436279296875
Down Loss = 0.9426319003105164


 90%|█████████ | 19/21 [12:28<01:20, 40.24s/it]

Up Loss = 0.9169675707817078
Down Loss = 0.9622205495834351


 95%|█████████▌| 20/21 [13:09<00:40, 40.26s/it]

Up Loss = 0.9360475540161133
Down Loss = 0.9374474287033081


100%|██████████| 21/21 [13:44<00:00, 39.25s/it]


Up Loss = 0.9067099094390869
epoch 15: down loss = 0.9644537369410197, up loss = 0.9350022701990037


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 0.9445286989212036


  5%|▍         | 1/21 [00:40<13:23, 40.15s/it]

Up Loss = 0.9225102663040161
Down Loss = 0.9684512615203857


 10%|▉         | 2/21 [01:20<12:43, 40.19s/it]

Up Loss = 0.9445434212684631
Down Loss = 0.9459406137466431


 14%|█▍        | 3/21 [02:00<12:03, 40.17s/it]

Up Loss = 0.9246608018875122
Down Loss = 0.928686261177063


 19%|█▉        | 4/21 [02:40<11:24, 40.26s/it]

Up Loss = 0.9048344492912292
Down Loss = 0.9240401983261108


 24%|██▍       | 5/21 [03:21<10:47, 40.45s/it]

Up Loss = 0.8986093997955322
Down Loss = 0.961356520652771


 29%|██▊       | 6/21 [04:01<10:04, 40.32s/it]

Up Loss = 0.9399195909500122
Down Loss = 0.9215950965881348


 33%|███▎      | 7/21 [04:41<09:21, 40.11s/it]

Up Loss = 0.9002512097358704
Down Loss = 0.9297942519187927


 38%|███▊      | 8/21 [05:20<08:37, 39.80s/it]

Up Loss = 0.9083471298217773
Down Loss = 0.9221631288528442


 43%|████▎     | 9/21 [05:59<07:55, 39.66s/it]

Up Loss = 0.9013670682907104
Down Loss = 0.9508518576622009


 48%|████▊     | 10/21 [06:39<07:14, 39.54s/it]

Up Loss = 0.9256950616836548
Down Loss = 0.9771779179573059


 52%|█████▏    | 11/21 [07:18<06:34, 39.41s/it]

Up Loss = 0.9510437250137329
Down Loss = 0.9964777231216431


 57%|█████▋    | 12/21 [07:57<05:53, 39.30s/it]

Up Loss = 0.9669716358184814
Down Loss = 0.9346427917480469


 62%|██████▏   | 13/21 [08:36<05:13, 39.15s/it]

Up Loss = 0.9067474603652954
Down Loss = 0.9490745067596436


 67%|██████▋   | 14/21 [09:15<04:33, 39.06s/it]

Up Loss = 0.9284818768501282
Down Loss = 0.9398898482322693


 71%|███████▏  | 15/21 [09:54<03:55, 39.29s/it]

Up Loss = 0.9090518355369568
Down Loss = 0.9243659973144531


 76%|███████▌  | 16/21 [10:34<03:16, 39.30s/it]

Up Loss = 0.9027585387229919
Down Loss = 0.9147158861160278


 81%|████████  | 17/21 [11:13<02:37, 39.42s/it]

Up Loss = 0.8953056931495667
Down Loss = 0.9385532140731812


 86%|████████▌ | 18/21 [11:53<01:58, 39.45s/it]

Up Loss = 0.91511470079422
Down Loss = 0.9306333661079407


 90%|█████████ | 19/21 [12:32<01:18, 39.35s/it]

Up Loss = 0.905809223651886
Down Loss = 0.9203726053237915


 95%|█████████▌| 20/21 [13:11<00:39, 39.35s/it]

Up Loss = 0.8959667086601257
Down Loss = 0.9247069358825684


100%|██████████| 21/21 [13:46<00:00, 39.36s/it]


Up Loss = 0.9038028717041016
epoch 16: down loss = 0.9403818420001439, up loss = 0.9167520318712506


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 0.9326671361923218


  5%|▍         | 1/21 [00:39<13:04, 39.23s/it]

Up Loss = 0.9115098714828491
Down Loss = 0.9212952852249146


 10%|▉         | 2/21 [01:18<12:24, 39.16s/it]

Up Loss = 0.904208779335022
Down Loss = 0.9076063632965088


 14%|█▍        | 3/21 [01:57<11:44, 39.15s/it]

Up Loss = 0.88543701171875
Down Loss = 0.9206892251968384


 19%|█▉        | 4/21 [02:36<11:02, 38.97s/it]

Up Loss = 0.8978924751281738
Down Loss = 0.9383308291435242


 24%|██▍       | 5/21 [03:15<10:25, 39.09s/it]

Up Loss = 0.9176621437072754
Down Loss = 0.8965587019920349


 29%|██▊       | 6/21 [03:54<09:46, 39.11s/it]

Up Loss = 0.8775029182434082
Down Loss = 0.9162015914916992


 33%|███▎      | 7/21 [04:33<09:07, 39.11s/it]

Up Loss = 0.8906963467597961
Down Loss = 0.9116599559783936


 38%|███▊      | 8/21 [05:13<08:30, 39.27s/it]

Up Loss = 0.895763099193573
Down Loss = 0.9159815311431885


 43%|████▎     | 9/21 [05:52<07:51, 39.28s/it]

Up Loss = 0.8975857496261597
Down Loss = 0.9189865589141846


 48%|████▊     | 10/21 [06:31<07:11, 39.25s/it]

Up Loss = 0.8969022035598755
Down Loss = 0.9010657072067261


 52%|█████▏    | 11/21 [07:11<06:32, 39.26s/it]

Up Loss = 0.8815586566925049
Down Loss = 0.9358094334602356


 57%|█████▋    | 12/21 [07:50<05:52, 39.22s/it]

Up Loss = 0.9158223867416382
Down Loss = 0.9222215414047241


 62%|██████▏   | 13/21 [08:29<05:13, 39.16s/it]

Up Loss = 0.9041599035263062
Down Loss = 0.9677647948265076


 67%|██████▋   | 14/21 [09:07<04:32, 38.95s/it]

Up Loss = 0.9387081861495972
Down Loss = 0.9177104830741882


 71%|███████▏  | 15/21 [09:46<03:53, 38.90s/it]

Up Loss = 0.8956961631774902
Down Loss = 0.9412530064582825


 76%|███████▌  | 16/21 [10:25<03:14, 38.92s/it]

Up Loss = 0.9169018268585205
Down Loss = 0.9401931166648865


 81%|████████  | 17/21 [11:04<02:35, 38.93s/it]

Up Loss = 0.9171451330184937
Down Loss = 0.8996994495391846


 86%|████████▌ | 18/21 [11:43<01:56, 39.00s/it]

Up Loss = 0.8776405453681946
Down Loss = 0.9172533750534058


 90%|█████████ | 19/21 [12:22<01:18, 39.03s/it]

Up Loss = 0.8926966786384583
Down Loss = 0.9213343858718872


 95%|█████████▌| 20/21 [13:02<00:39, 39.17s/it]

Up Loss = 0.9019105434417725
Down Loss = 0.9494084119796753


100%|██████████| 21/21 [13:36<00:00, 38.90s/it]


Up Loss = 0.9243971109390259
epoch 17: down loss = 0.9235090897196815, up loss = 0.9019903682527088


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 0.9085628390312195


  5%|▍         | 1/21 [00:39<13:18, 39.90s/it]

Up Loss = 0.8923318982124329
Down Loss = 0.887319028377533


 10%|▉         | 2/21 [01:19<12:38, 39.95s/it]

Up Loss = 0.8673726916313171
Down Loss = 0.9088932871818542


 14%|█▍        | 3/21 [02:00<12:00, 40.04s/it]

Up Loss = 0.8870331645011902
Down Loss = 0.9172384738922119


 19%|█▉        | 4/21 [02:40<11:23, 40.18s/it]

Up Loss = 0.9005952477455139
Down Loss = 0.9073423147201538


 24%|██▍       | 5/21 [03:20<10:43, 40.22s/it]

Up Loss = 0.88172447681427
Down Loss = 0.9302887916564941


 29%|██▊       | 6/21 [04:00<10:01, 40.11s/it]

Up Loss = 0.9089902639389038
Down Loss = 0.9053750038146973


 33%|███▎      | 7/21 [04:40<09:21, 40.11s/it]

Up Loss = 0.8815480470657349
Down Loss = 0.8954474925994873


 38%|███▊      | 8/21 [05:20<08:41, 40.12s/it]

Up Loss = 0.8734821081161499
Down Loss = 0.9148974418640137


 43%|████▎     | 9/21 [06:01<08:02, 40.17s/it]

Up Loss = 0.9004456996917725
Down Loss = 0.8957108855247498


 48%|████▊     | 10/21 [06:41<07:23, 40.28s/it]

Up Loss = 0.8711123466491699
Down Loss = 0.9073286056518555


 52%|█████▏    | 11/21 [07:21<06:41, 40.16s/it]

Up Loss = 0.8889757394790649
Down Loss = 0.9061446785926819


 57%|█████▋    | 12/21 [08:01<06:01, 40.16s/it]

Up Loss = 0.8855907320976257
Down Loss = 0.8865103721618652


 62%|██████▏   | 13/21 [08:41<05:21, 40.16s/it]

Up Loss = 0.86712646484375
Down Loss = 0.901993989944458


 67%|██████▋   | 14/21 [09:21<04:40, 40.06s/it]

Up Loss = 0.8780992031097412
Down Loss = 0.9067179560661316


 71%|███████▏  | 15/21 [10:01<03:59, 39.85s/it]

Up Loss = 0.8875411748886108
Down Loss = 0.8752050399780273


 76%|███████▌  | 16/21 [10:40<03:18, 39.66s/it]

Up Loss = 0.8540025353431702
Down Loss = 0.8933007717132568


 81%|████████  | 17/21 [11:19<02:37, 39.46s/it]

Up Loss = 0.8708782196044922
Down Loss = 0.9791406989097595


 86%|████████▌ | 18/21 [11:58<01:58, 39.34s/it]

Up Loss = 0.9556375741958618
Down Loss = 0.921501874923706


 90%|█████████ | 19/21 [12:37<01:18, 39.15s/it]

Up Loss = 0.8984602093696594
Down Loss = 0.952653169631958


 95%|█████████▌| 20/21 [13:15<00:39, 39.04s/it]

Up Loss = 0.9318602085113525
Down Loss = 0.9124162197113037


100%|██████████| 21/21 [13:50<00:00, 39.55s/it]


Up Loss = 0.8862473964691162
epoch 18: down loss = 0.9101899493308294, up loss = 0.8890026382037571


  0%|          | 0/21 [00:00<?, ?it/s]

Down Loss = 0.9055657386779785


  5%|▍         | 1/21 [00:38<12:56, 38.81s/it]

Up Loss = 0.8813619017601013
Down Loss = 0.8867979049682617


 10%|▉         | 2/21 [01:18<12:30, 39.51s/it]

Up Loss = 0.8663809299468994
Down Loss = 0.8912870287895203


 14%|█▍        | 3/21 [01:57<11:47, 39.30s/it]

Up Loss = 0.8713809251785278
Down Loss = 0.8920881748199463


 19%|█▉        | 4/21 [02:37<11:07, 39.29s/it]

Up Loss = 0.8718915581703186
Down Loss = 0.8901209831237793


 24%|██▍       | 5/21 [03:16<10:26, 39.15s/it]

Up Loss = 0.8735986948013306
Down Loss = 0.897381603717804


 29%|██▊       | 6/21 [03:55<09:48, 39.22s/it]

Up Loss = 0.8844000101089478
Down Loss = 0.883875846862793


 33%|███▎      | 7/21 [04:34<09:07, 39.10s/it]

Up Loss = 0.8708662986755371
Down Loss = 0.8681777715682983


 38%|███▊      | 8/21 [05:13<08:28, 39.12s/it]

Up Loss = 0.8429283499717712
Down Loss = 0.9336585998535156


 43%|████▎     | 9/21 [05:51<07:46, 38.91s/it]

Up Loss = 0.9104944467544556
Down Loss = 0.886012077331543


 48%|████▊     | 10/21 [06:30<07:07, 38.85s/it]

Up Loss = 0.8635478615760803
Down Loss = 0.9158377647399902


 52%|█████▏    | 11/21 [07:09<06:28, 38.87s/it]

Up Loss = 0.8929972648620605
Down Loss = 0.9067869186401367


 57%|█████▋    | 12/21 [07:48<05:50, 38.89s/it]

Up Loss = 0.8902654647827148
Down Loss = 0.8681906461715698


 62%|██████▏   | 13/21 [08:27<05:12, 39.02s/it]

Up Loss = 0.844252347946167
Down Loss = 0.8983526229858398


 67%|██████▋   | 14/21 [09:06<04:32, 38.97s/it]

Up Loss = 0.8788753747940063
Down Loss = 0.8924828171730042


 71%|███████▏  | 15/21 [09:45<03:53, 38.95s/it]

Up Loss = 0.8755800724029541
Down Loss = 0.8999364376068115


 76%|███████▌  | 16/21 [10:24<03:14, 38.92s/it]

Up Loss = 0.8798635601997375
Down Loss = 0.8934051990509033


 81%|████████  | 17/21 [11:04<02:36, 39.18s/it]

Up Loss = 0.8667960166931152
Down Loss = 0.9227851629257202


 86%|████████▌ | 18/21 [11:43<01:57, 39.15s/it]

Up Loss = 0.9001352190971375
Down Loss = 0.8816632032394409


 90%|█████████ | 19/21 [12:22<01:18, 39.29s/it]

Up Loss = 0.861221432685852
Down Loss = 0.9128678441047668


 95%|█████████▌| 20/21 [13:02<00:39, 39.27s/it]

Up Loss = 0.8962550163269043
Down Loss = 0.9261093139648438


100%|██████████| 21/21 [13:36<00:00, 38.87s/it]

Up Loss = 0.9044310450553894
epoch 19: down loss = 0.8977801743007842, up loss = 0.8775011329423814





In [17]:
torch.save(up_changer.state_dict(), 'tim_files/up_changer.pth')

In [9]:
up_changer = AdjacencyChangerUp(d=2,f_in=32).to(DEVICE)
up_changer.load_state_dict(torch.load('tim_files/up_changer.pth'))

<All keys matched successfully>

In [10]:
lr_X_test_dim1 = torch.load('model_autoencoder/test_hr_1.pt')
lr_X_test_dim2 = torch.load('model_autoencoder/test_hr_2.pt')


lr_test_X_all = torch.empty((112, 320, 32))
for i in range(len(lr_test_X_all)):
    a, b = lr_X_test_dim1[i], lr_X_test_dim2[i]
    lr_test_X_all[i] = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])


In [11]:
pred = torch.empty((112, 268, 268))
up_changer.eval()
for i in range(len(lr_test)):
    pred[i] = up_changer(lr_test_X_all[i].to(DEVICE), lr_test[i].to(DEVICE))[-1].detach()