In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch, sys, os, copy
sys.path.append('./')

from tqdm import tqdm as tqdm_base
def tqdm(*args, **kwargs):
    if hasattr(tqdm_base, '_instances'):
        for instance in list(tqdm_base._instances):
            tqdm_base._decr_instances(instance)
    return tqdm_base(*args, **kwargs)

from ncsnv2.models        import get_sigmas
from ncsnv2.models.ema    import EMAHelper
from ncsnv2.models.ncsnv2 import NCSNv2Deepest
from ncsnv2.losses        import get_optimizer
from ncsnv2.losses.dsm    import anneal_dsm_score_estimation

import scipy.io as sio
import random
from loaders          import Channels
from torch.utils.data import DataLoader

from dotmap import DotMap

from sample_generator import *

%load_ext autoreload
%autoreload 2


  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Always !!!
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32       = False

# GPU
os.environ["CUDA_DEVICE_ORDER"]    = "PCI_BUS_ID";
os.environ["CUDA_VISIBLE_DEVICES"] = "0";

# Model config
config          = DotMap()
config.device   = 'cuda:0'
# Inner model
config.model.ema           = True
config.model.ema_rate      = 0.999
config.model.normalization = 'InstanceNorm++'
config.model.nonlinearity  = 'elu'
config.model.sigma_dist    = 'geometric'
config.model.num_classes   = 2311 # Number of train sigmas and 'N'
config.model.ngf           = 32

# Optimizer
config.optim.weight_decay  = 0.000 # No weight decay
config.optim.optimizer     = 'Adam'
config.optim.lr            = 0.0001
config.optim.beta1         = 0.9
config.optim.amsgrad       = False
config.optim.eps           = 0.001

# Training
config.training.batch_size     = 32
config.training.num_workers    = 4
config.training.n_epochs       = 40
config.training.anneal_power   = 2
config.training.log_all_sigmas = False
config.training.eval_freq      = 50 # In epochs

# Data
config.data.channel        = '3GPP' # Training and validation
config.data.channels       = 2 # {Re, Im}
config.data.num_pilots     = 8
config.data.noise_std      = 0.01 # 'Beta' in paper
config.data.image_size     = [64, 32] # Channel size = Nr x Nt
config.data.mixed_channels = False
config.data.norm_channels  = 'global' # Optional, no major impact
config.data.spacing_list   = [0.5] # Training and validation

# Universal seeds
train_seed, val_seed = 1234, 4321


In [4]:
# Get datasets and loaders for channels
# train_samples = 8000
# generator = sample_generator(train_samples, 4, config.data.image_size[0])

# H, y, x, j_indices, noise_sigma = generator.give_batch_data(config.data.image_size[1], snr_db_min=0,
#                                                                 snr_db_max=0, 
#                                                                 batch_size=train_samples,
#                                                                 correlated_flag=True, 
#                                                                 rho=0.6
#                                                                 )

# del y, x, j_indices  

# H = H.detach().numpy()
# H_complex = H[:,0:config.data.image_size[0], 0:config.data.image_size[1]] + 1j * H[:,config.data.image_size[0]:, 0:config.data.image_size[1]] 

##
# NT = 32
train_samples = 7500
mat_contents = sio.loadmat('data/H_bank_64.mat')
H = mat_contents['H_bank']
# H = torch.tensor(H[:, :, 0:config.NT])
H_complex = torch.tensor(H[:train_samples, :, :]).detach().numpy()#Pick up NT random users from 100.

#############################################
dataset     = Channels(train_seed, config, H = H_complex, norm=config.data.norm_channels)
dataloader  = DataLoader(dataset, batch_size=config.training.batch_size, 
                         shuffle=True, num_workers=config.training.num_workers,
                         drop_last=True)

######

# val_samples = 2000
# H_val, y, x, j_indices, noise_sigma = generator.give_batch_data(config.data.image_size[1], snr_db_min=0,
#                                                                 snr_db_max=0, 
#                                                                 batch_size=val_samples, 
#                                                                 correlated_flag=True, 
#                                                                 rho=0.6
#                                                                 )

# del y, x, j_indices  

# H_val = H_val.detach().numpy()
# H_val_complex = H_val[:,0:config.data.image_size[0], 0:config.data.image_size[1]] + 1j * H_val[:,config.data.image_size[0]:, 0:config.data.image_size[1]] 


# H = torch.tensor(H[:, :, 0:config.NT])
H_val_complex = torch.tensor(H[train_samples:9500, :, :]).detach().numpy()#Pick up NT random users from 100.
val_samples = H_val_complex.shape[0]


####################################################
# Create separate validation sets
val_datasets, val_loaders, val_iters = [], [], []
for idx in range(len(config.data.spacing_list)):
    # Validation config
    val_config = copy.deepcopy(config)
    val_config.data.spacing_list = [config.data.spacing_list[idx]]
    # Create locals
    val_datasets.append(Channels(val_seed, val_config, H = H_val_complex,  norm=config.data.norm_channels))
    val_loaders.append(DataLoader(
        val_datasets[-1], batch_size=len(val_datasets[-1]),
        shuffle=False, num_workers=0, drop_last=True))
    val_iters.append(iter(val_loaders[-1])) # For validation

In [5]:
# Construct pairwise distances
if False: # Set to true to follow [Song '20] exactly
    dist_matrix   = np.zeros((len(dataset), len(dataset)))
    flat_channels = dataset.channels.reshape((len(dataset), -1))
    for idx in tqdm(range(len(dataset))):
        dist_matrix[idx] = np.linalg.norm(
            flat_channels[idx][None, :] - flat_channels, axis=-1)
    dist_matrix_vec = dist_matrix.reshape(dist_matrix.shape[0] * dist_matrix.shape[1], 1).squeeze()
    dist_matrix_vec[np.argmax(dist_matrix.reshape(dist_matrix.shape[0] * dist_matrix.shape[1], 1).squeeze())]        

# Pre-determined values
# config.model.sigma_begin = 39.15 # !!! For CDL-D mixture
# config.model.sigma_begin = 27.77 # !!! For CDL-D lambda/2
config.model.sigma_begin = 30 # !!! For CDL-C
    
# Apply Song's third recommendation
if False:
    from scipy.stats import norm
    candidate_gamma = np.logspace(np.log10(0.9), np.log10(0.99999), 1000)
    gamma_criterion = np.zeros((len(candidate_gamma)))
    for idx, gamma in enumerate(candidate_gamma):
        gamma_criterion[idx] = \
            norm.cdf(np.sqrt(2 * np.prod(dataset[0]['H'].shape)) * (gamma - 1) + 3*gamma) - \
            norm.cdf(np.sqrt(2 * np.prod(dataset[0]['H'].shape)) * (gamma - 1) - 3*gamma)
    best_idx = np.argmin(np.abs(gamma_criterion - 0.5))
    
# Pre-determined
config.model.sigma_rate = 0.995 # !!! For everything
config.model.sigma_end  = config.model.sigma_begin * \
    config.model.sigma_rate ** (config.model.num_classes - 1)

# Choose the step size (epsilon) according to [Song '20]
candidate_steps = np.logspace(-13, -8, 1000)
step_criterion  = np.zeros((len(candidate_steps)))
gamma_rate      = 1 / config.model.sigma_rate
for idx, step in enumerate(candidate_steps):
    step_criterion[idx] = (1 - step / config.model.sigma_end ** 2) \
        ** (2 * config.model.num_classes) * (gamma_rate ** 2 -
            2 * step / (config.model.sigma_end ** 2 - config.model.sigma_end ** 2 * (
                1 - step / config.model.sigma_end ** 2) ** 2)) + \
            2 * step / (config.model.sigma_end ** 2 - config.model.sigma_end ** 2 * (
                1 - step / config.model.sigma_end ** 2) ** 2)
best_idx = np.argmin(np.abs(step_criterion - 1.))
config.model.step_size = candidate_steps[best_idx]


In [6]:
# Get a model
diffuser = NCSNv2Deepest(config)
diffuser = diffuser.cuda()
# Get optimizer
optimizer = get_optimizer(config, diffuser.parameters())

In [7]:
val_iters[0]

<torch.utils.data.dataloader._SingleProcessDataLoaderIter at 0x7ff60567bf40>

In [8]:
# Counter
start_epoch = 0
step = 0
if config.model.ema:
    ema_helper = EMAHelper(mu=config.model.ema_rate)
    ema_helper.register(diffuser)

# Get a collection of sigma values
sigmas = get_sigmas(config)

# Always the same initial points and data for validation
val_H_list = []
for idx in range(len(config.data.spacing_list)):
    val_sample = next(val_iters[idx])
    val_H_list.append(val_sample['H_herm'].cuda())

# More logging
config.log_path = 'models/\
numLambdas%d_lambdaMin%.1f_lambdaMax%.1f_sigmaT%.1f' % (
    len(config.data.spacing_list), np.min(config.data.spacing_list),
    np.max(config.data.spacing_list), config.model.sigma_begin)
if not os.path.exists(config.log_path):
    os.makedirs(config.log_path)
# No sigma logging
hook = test_hook = None

# Logged metrics
train_loss, val_loss  = [], []
val_errors, val_epoch = [], []

for epoch in tqdm(range(start_epoch, config.training.n_epochs)):
    for i, sample in tqdm(enumerate(dataloader)):
        # Safety check
        diffuser.train()
        step += 1
        
        # Move data to device
        for key in sample:
            sample[key] = sample[key].cuda()
        
        # Get loss on Hermitian channels
        loss = anneal_dsm_score_estimation(
            diffuser, sample['H_herm'], sigmas, None, 
            config.training.anneal_power, hook)
        
        # Keep a running loss
        if step == 1:
            running_loss = loss.item()
        else:
            running_loss = 0.99 * running_loss + 0.01 * loss.item()
        # Log
        train_loss.append(loss.item())
        
        # Step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # EMA update
        if config.model.ema:
            ema_helper.update(diffuser)
            
        # Verbose
        if step % 100 == 0:
            if config.model.ema:
                val_score = ema_helper.ema_copy(diffuser)
            else:
                val_score = diffuser
            
            # For each validation setup
            local_val_losses = []
            for idx in range(len(config.data.spacing_list)):
                with torch.no_grad():
                    val_dsm_loss = \
                        anneal_dsm_score_estimation(
                            val_score, val_H_list[idx],
                            sigmas, None,
                            config.training.anneal_power,
                            hook=test_hook)
                # Store
                local_val_losses.append(val_dsm_loss.item())
            # Sanity delete
            del val_score
            # Log
            val_loss.append(local_val_losses)
                
            # Print
            if len(local_val_losses) == 1:
                print('Epoch %d, Step %d, Train Loss (EMA) %.3f, \
    Val. Loss %.3f' % (
                    epoch, step, running_loss, 
                    local_val_losses[0]))
            elif len(local_val_losses) == 2:
                print('Epoch %d, Step %d, Train Loss (EMA) %.3f, \
    Val. Loss (Split) %.3f %.3f' % (
                    epoch, step, running_loss, 
                    local_val_losses[0], local_val_losses[1]))
        
# Save snapshot
torch.save({'model_state': diffuser.state_dict(),
            'optim_state': optimizer.state_dict(),
            'config': config,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_errors': val_errors,
            'val_epoch': val_epoch}, 
   os.path.join(config.log_path, 'final_model_3gpp_64.pt'))

103it [00:07,  9.18it/s]0:00<?, ?it/s]

Epoch 0, Step 100, Train Loss (EMA) 2110.248,     Val. Loss 2489.627


203it [00:13,  9.28it/s]

Epoch 0, Step 200, Train Loss (EMA) 1665.342,     Val. Loss 2227.920


234it [00:15, 14.83it/s]
68it [00:04,  8.82it/s]00:15<10:18, 15.86s/it]

Epoch 1, Step 300, Train Loss (EMA) 1496.231,     Val. Loss 2038.642


168it [00:10,  9.22it/s]

Epoch 1, Step 400, Train Loss (EMA) 1393.759,     Val. Loss 1878.101


234it [00:14, 16.20it/s]
34it [00:02,  9.11it/s]00:30<09:33, 15.10s/it]

Epoch 2, Step 500, Train Loss (EMA) 1373.710,     Val. Loss 1774.015


134it [00:08,  9.23it/s]

Epoch 2, Step 600, Train Loss (EMA) 1370.615,     Val. Loss 1686.600


234it [00:14, 15.72it/s]
  8%|▊         | 3/40 [00:45<09:17, 15.06s/it]

Epoch 2, Step 700, Train Loss (EMA) 1339.047,     Val. Loss 1591.065


100it [00:06,  9.05it/s]

Epoch 3, Step 800, Train Loss (EMA) 1325.777,     Val. Loss 1555.008


200it [00:12,  9.10it/s]

Epoch 3, Step 900, Train Loss (EMA) 1313.155,     Val. Loss 1503.526


234it [00:14, 16.07it/s]
66it [00:04,  9.10it/s]01:00<08:56, 14.91s/it]

Epoch 4, Step 1000, Train Loss (EMA) 1324.405,     Val. Loss 1474.070


166it [00:10,  9.20it/s]

Epoch 4, Step 1100, Train Loss (EMA) 1323.089,     Val. Loss 1438.519


234it [00:14, 16.19it/s]
32it [00:02,  9.02it/s]01:14<08:37, 14.79s/it]

Epoch 5, Step 1200, Train Loss (EMA) 1319.993,     Val. Loss 1411.033


132it [00:08,  9.23it/s]

Epoch 5, Step 1300, Train Loss (EMA) 1303.935,     Val. Loss 1388.518


232it [00:14,  9.22it/s]

Epoch 5, Step 1400, Train Loss (EMA) 1301.131,     Val. Loss 1374.804


234it [00:14, 15.73it/s]
98it [00:06,  9.16it/s]01:29<08:25, 14.86s/it]

Epoch 6, Step 1500, Train Loss (EMA) 1308.065,     Val. Loss 1390.317


198it [00:12,  9.29it/s]

Epoch 6, Step 1600, Train Loss (EMA) 1293.928,     Val. Loss 1345.402


234it [00:14, 16.32it/s]
64it [00:04,  9.08it/s]01:44<08:06, 14.74s/it]

Epoch 7, Step 1700, Train Loss (EMA) 1267.218,     Val. Loss 1352.702


164it [00:10,  9.21it/s]

Epoch 7, Step 1800, Train Loss (EMA) 1276.967,     Val. Loss 1319.983


234it [00:14, 16.19it/s]
30it [00:02,  9.08it/s]01:58<07:49, 14.69s/it]

Epoch 8, Step 1900, Train Loss (EMA) 1281.597,     Val. Loss 1321.518


130it [00:08,  9.22it/s]

Epoch 8, Step 2000, Train Loss (EMA) 1280.671,     Val. Loss 1358.045


230it [00:14,  9.51it/s]

Epoch 8, Step 2100, Train Loss (EMA) 1276.870,     Val. Loss 1334.858


234it [00:14, 16.16it/s]
96it [00:05,  8.85it/s]02:13<07:34, 14.66s/it]

Epoch 9, Step 2200, Train Loss (EMA) 1278.860,     Val. Loss 1305.707


196it [00:12,  9.17it/s]

Epoch 9, Step 2300, Train Loss (EMA) 1277.381,     Val. Loss 1313.181


234it [00:14, 16.29it/s]
62it [00:04,  8.36it/s][02:27<07:18, 14.61s/it]

Epoch 10, Step 2400, Train Loss (EMA) 1259.874,     Val. Loss 1291.757


162it [00:10,  9.27it/s]

Epoch 10, Step 2500, Train Loss (EMA) 1260.840,     Val. Loss 1311.625


234it [00:14, 16.39it/s]
28it [00:02,  9.07it/s][02:42<07:01, 14.55s/it]

Epoch 11, Step 2600, Train Loss (EMA) 1258.374,     Val. Loss 1283.450


128it [00:08,  9.11it/s]

Epoch 11, Step 2700, Train Loss (EMA) 1253.754,     Val. Loss 1295.005


228it [00:14,  9.20it/s]

Epoch 11, Step 2800, Train Loss (EMA) 1256.883,     Val. Loss 1274.529


234it [00:14, 15.70it/s]
94it [00:05,  9.18it/s][02:57<06:51, 14.69s/it]

Epoch 12, Step 2900, Train Loss (EMA) 1263.055,     Val. Loss 1259.356


194it [00:11,  9.30it/s]

Epoch 12, Step 3000, Train Loss (EMA) 1253.422,     Val. Loss 1275.622


234it [00:14, 16.49it/s]
60it [00:04,  9.08it/s][03:11<06:33, 14.58s/it]

Epoch 13, Step 3100, Train Loss (EMA) 1256.146,     Val. Loss 1288.296


160it [00:10,  9.20it/s]

Epoch 13, Step 3200, Train Loss (EMA) 1249.346,     Val. Loss 1270.689


234it [00:14, 16.12it/s]
26it [00:01,  9.08it/s][03:26<06:19, 14.60s/it]

Epoch 14, Step 3300, Train Loss (EMA) 1257.745,     Val. Loss 1264.868


126it [00:08,  9.22it/s]

Epoch 14, Step 3400, Train Loss (EMA) 1264.316,     Val. Loss 1265.468


226it [00:14,  9.23it/s]

Epoch 14, Step 3500, Train Loss (EMA) 1253.413,     Val. Loss 1253.771


234it [00:14, 15.84it/s]
92it [00:05,  9.18it/s][03:41<06:07, 14.69s/it]

Epoch 15, Step 3600, Train Loss (EMA) 1241.925,     Val. Loss 1284.962


192it [00:11,  9.29it/s]

Epoch 15, Step 3700, Train Loss (EMA) 1263.640,     Val. Loss 1256.818


234it [00:14, 16.42it/s]
59it [00:03,  9.15it/s][03:55<05:50, 14.60s/it]

Epoch 16, Step 3800, Train Loss (EMA) 1255.652,     Val. Loss 1270.849


159it [00:10,  9.21it/s]

Epoch 16, Step 3900, Train Loss (EMA) 1254.642,     Val. Loss 1266.799


234it [00:14, 16.23it/s]
24it [00:01,  9.05it/s][04:10<05:35, 14.58s/it]

Epoch 17, Step 4000, Train Loss (EMA) 1224.568,     Val. Loss 1285.600


124it [00:08,  9.21it/s]

Epoch 17, Step 4100, Train Loss (EMA) 1232.785,     Val. Loss 1278.444


224it [00:14,  9.22it/s]

Epoch 17, Step 4200, Train Loss (EMA) 1219.205,     Val. Loss 1302.586


234it [00:14, 15.70it/s]
90it [00:05,  9.09it/s][04:25<05:23, 14.72s/it]

Epoch 18, Step 4300, Train Loss (EMA) 1221.655,     Val. Loss 1245.603


190it [00:11,  9.19it/s]

Epoch 18, Step 4400, Train Loss (EMA) 1241.374,     Val. Loss 1247.617


234it [00:14, 16.14it/s]
56it [00:03,  9.16it/s][04:39<05:08, 14.69s/it]

Epoch 19, Step 4500, Train Loss (EMA) 1245.351,     Val. Loss 1251.928


156it [00:09,  9.27it/s]

Epoch 19, Step 4600, Train Loss (EMA) 1247.545,     Val. Loss 1239.051


234it [00:14, 16.34it/s]
22it [00:01,  8.99it/s][04:54<04:52, 14.62s/it]

Epoch 20, Step 4700, Train Loss (EMA) 1232.248,     Val. Loss 1247.436


122it [00:07,  9.20it/s]

Epoch 20, Step 4800, Train Loss (EMA) 1250.549,     Val. Loss 1254.644


222it [00:14,  9.21it/s]

Epoch 20, Step 4900, Train Loss (EMA) 1246.207,     Val. Loss 1253.512


234it [00:14, 15.74it/s]
88it [00:05,  9.04it/s][05:09<04:39, 14.74s/it]

Epoch 21, Step 5000, Train Loss (EMA) 1216.244,     Val. Loss 1256.410


188it [00:11,  8.39it/s]

Epoch 21, Step 5100, Train Loss (EMA) 1221.009,     Val. Loss 1199.329


234it [00:14, 16.10it/s]
54it [00:03,  9.10it/s][05:23<04:24, 14.71s/it]

Epoch 22, Step 5200, Train Loss (EMA) 1223.699,     Val. Loss 1265.360


154it [00:09,  9.18it/s]

Epoch 22, Step 5300, Train Loss (EMA) 1228.568,     Val. Loss 1204.359


234it [00:14, 16.25it/s]
20it [00:01,  9.00it/s][05:38<04:09, 14.66s/it]

Epoch 23, Step 5400, Train Loss (EMA) 1225.635,     Val. Loss 1214.619


120it [00:07,  9.22it/s]

Epoch 23, Step 5500, Train Loss (EMA) 1223.654,     Val. Loss 1222.227


220it [00:13,  9.20it/s]

Epoch 23, Step 5600, Train Loss (EMA) 1228.461,     Val. Loss 1238.422


234it [00:14, 15.78it/s]
86it [00:05,  9.17it/s][05:53<03:55, 14.75s/it]

Epoch 24, Step 5700, Train Loss (EMA) 1222.914,     Val. Loss 1240.734


186it [00:11,  8.84it/s]

Epoch 24, Step 5800, Train Loss (EMA) 1218.419,     Val. Loss 1238.736


234it [00:14, 16.36it/s]
52it [00:03,  9.04it/s][06:07<03:39, 14.65s/it]

Epoch 25, Step 5900, Train Loss (EMA) 1209.650,     Val. Loss 1194.981


152it [00:09,  9.21it/s]

Epoch 25, Step 6000, Train Loss (EMA) 1214.071,     Val. Loss 1236.041


234it [00:14, 16.21it/s]
18it [00:01,  8.80it/s][06:22<03:24, 14.63s/it]

Epoch 26, Step 6100, Train Loss (EMA) 1187.392,     Val. Loss 1226.062


118it [00:07,  9.22it/s]

Epoch 26, Step 6200, Train Loss (EMA) 1208.283,     Val. Loss 1246.663


218it [00:13,  9.19it/s]

Epoch 26, Step 6300, Train Loss (EMA) 1207.397,     Val. Loss 1198.156


234it [00:14, 15.71it/s]
84it [00:05,  9.08it/s][06:37<03:11, 14.74s/it]

Epoch 27, Step 6400, Train Loss (EMA) 1222.904,     Val. Loss 1221.794


184it [00:11,  8.77it/s]

Epoch 27, Step 6500, Train Loss (EMA) 1218.725,     Val. Loss 1239.720


234it [00:14, 16.13it/s]
50it [00:03,  9.05it/s][06:52<02:56, 14.71s/it]

Epoch 28, Step 6600, Train Loss (EMA) 1227.591,     Val. Loss 1237.428


150it [00:09,  9.22it/s]

Epoch 28, Step 6700, Train Loss (EMA) 1217.377,     Val. Loss 1209.030


234it [00:14, 16.16it/s]
16it [00:01,  8.80it/s][07:06<02:41, 14.68s/it]

Epoch 29, Step 6800, Train Loss (EMA) 1213.499,     Val. Loss 1238.647


116it [00:07,  9.25it/s]

Epoch 29, Step 6900, Train Loss (EMA) 1215.544,     Val. Loss 1252.011


216it [00:13,  9.24it/s]

Epoch 29, Step 7000, Train Loss (EMA) 1200.249,     Val. Loss 1255.417


234it [00:14, 15.79it/s]
82it [00:05,  9.13it/s][07:21<02:27, 14.76s/it]

Epoch 30, Step 7100, Train Loss (EMA) 1188.798,     Val. Loss 1274.186


182it [00:11,  8.44it/s]

Epoch 30, Step 7200, Train Loss (EMA) 1206.816,     Val. Loss 1204.625


234it [00:14, 16.27it/s]
48it [00:03,  9.11it/s][07:36<02:12, 14.68s/it]

Epoch 31, Step 7300, Train Loss (EMA) 1209.681,     Val. Loss 1233.842


148it [00:09,  9.13it/s]

Epoch 31, Step 7400, Train Loss (EMA) 1219.155,     Val. Loss 1216.010


234it [00:14, 16.17it/s]
14it [00:01,  8.65it/s][07:50<01:57, 14.66s/it]

Epoch 32, Step 7500, Train Loss (EMA) 1205.837,     Val. Loss 1240.499


114it [00:07,  9.12it/s]

Epoch 32, Step 7600, Train Loss (EMA) 1208.378,     Val. Loss 1195.259


214it [00:13,  9.16it/s]

Epoch 32, Step 7700, Train Loss (EMA) 1228.600,     Val. Loss 1240.715


234it [00:15, 15.53it/s]
80it [00:05,  9.08it/s][08:05<01:43, 14.82s/it]

Epoch 33, Step 7800, Train Loss (EMA) 1212.209,     Val. Loss 1251.340


180it [00:11,  9.17it/s]

Epoch 33, Step 7900, Train Loss (EMA) 1208.114,     Val. Loss 1224.079


234it [00:14, 16.08it/s]
46it [00:03,  9.05it/s][08:20<01:28, 14.78s/it]

Epoch 34, Step 8000, Train Loss (EMA) 1205.259,     Val. Loss 1177.761


146it [00:09,  9.14it/s]

Epoch 34, Step 8100, Train Loss (EMA) 1193.429,     Val. Loss 1211.695


234it [00:14, 16.08it/s]
12it [00:01,  8.53it/s][08:35<01:13, 14.75s/it]

Epoch 35, Step 8200, Train Loss (EMA) 1194.197,     Val. Loss 1219.988


112it [00:07,  9.25it/s]

Epoch 35, Step 8300, Train Loss (EMA) 1198.377,     Val. Loss 1198.745


212it [00:13,  9.26it/s]

Epoch 35, Step 8400, Train Loss (EMA) 1193.405,     Val. Loss 1214.414


234it [00:14, 15.88it/s]
78it [00:04,  9.07it/s][08:50<00:59, 14.79s/it]

Epoch 36, Step 8500, Train Loss (EMA) 1177.915,     Val. Loss 1195.784


178it [00:11,  9.49it/s]

Epoch 36, Step 8600, Train Loss (EMA) 1175.796,     Val. Loss 1216.014


234it [00:14, 16.56it/s]
44it [00:03,  9.09it/s][09:04<00:43, 14.63s/it]

Epoch 37, Step 8700, Train Loss (EMA) 1190.390,     Val. Loss 1194.076


144it [00:09,  9.17it/s]

Epoch 37, Step 8800, Train Loss (EMA) 1201.131,     Val. Loss 1182.627


234it [00:14, 16.21it/s]
10it [00:01,  8.20it/s][09:18<00:29, 14.61s/it]

Epoch 38, Step 8900, Train Loss (EMA) 1182.741,     Val. Loss 1219.599


110it [00:07,  9.15it/s]

Epoch 38, Step 9000, Train Loss (EMA) 1176.191,     Val. Loss 1220.218


210it [00:13,  9.23it/s]

Epoch 38, Step 9100, Train Loss (EMA) 1186.599,     Val. Loss 1197.948


234it [00:14, 15.63it/s]
76it [00:04,  9.08it/s][09:34<00:14, 14.76s/it]

Epoch 39, Step 9200, Train Loss (EMA) 1202.828,     Val. Loss 1177.023


176it [00:11,  9.17it/s]

Epoch 39, Step 9300, Train Loss (EMA) 1195.533,     Val. Loss 1212.964


234it [00:14, 16.11it/s]
100%|██████████| 40/40 [09:48<00:00, 14.72s/it]
