In [49]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [50]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import os

from tqdm import tqdm



import torchvision
from torchvision import transforms
from torch.utils.data import Subset, Dataset, DataLoader

import argparse


from utils.script_utils import select_dataset, init_model, setup_dataset_training, save_checkpoint_from, upload_checkpoint_in
from utils.loss import wasser_loss, entropy_limit_loss, kl_loss_from_ready
from main_utils import get_models_class_list, get_base_model_parameters


from timeit import default_timer as timer

In [51]:
# Reproducibility: seed torch's global RNG before anything random happens.
torch.manual_seed(0)

# Supported configuration values; the run settings are upper-cased before being checked
# against these lists.
GOOD_DATASET_TYPE = 'MNIST CIFAR10 CELEBA FMNIST FASHIONMNIST'.split()
GOOD_MODEL_TYPE = 'VAE AE LRAE IRMAE'.split()
GOOD_ARCHITECTURE_TYPE = 'V1 NIPS'.split()


### Setup script

In [52]:
## Main work parameters
# Compute device, e.g. 'cpu' or a specific GPU such as 'cuda:2'.
DEVICE = 'cpu'
MODEL_TYPE = 'AE'           # must be one of GOOD_MODEL_TYPE

ARCHITECTURE_TYPE = 'NIPS'  # must be one of GOOD_ARCHITECTURE_TYPE
DATASET_TYPE = 'MNIST'      # must be one of GOOD_DATASET_TYPE

ALPHA = 0                   # weight of additional_loss (kl_loss_from_ready / entropy_limit_loss)
BATCH_SIZE = 256

EPOCHS = 5                  # number of epochs run by the training cell (51 for a full run)
LEARNING_RATE = 1e-4

N_LATENT = 8                # bottleneck size; -1 keeps the architecture's default

# for LRAE
N_BINS = 20


In [53]:
#### setup runs
# Pick the checkpoint-name prefix and output directory for this dataset family.
print("Setup runs")
_dataset_key = DATASET_TYPE.upper()
if _dataset_key in ['MNIST', 'FMNIST', 'FASHIONMNIST']:
    MODEL_NAME_PREF = 'test_NIPS__'
    SAVE_DIR = 'test_NIPS/data_n'
elif _dataset_key in ['CIFAR10']:
    MODEL_NAME_PREF = 'test_NIPS__'
    SAVE_DIR = 'test_NIPS'
elif _dataset_key in ['CELEBA']:
    MODEL_NAME_PREF = 'test1_NIPS__'
    SAVE_DIR = 'test_NIPS'
else:
    print("Warning! the default run setups was not setuped!")
    # Fallback so later cells do not hit a NameError for an unrecognised dataset.
    MODEL_NAME_PREF = 'Test_'
    SAVE_DIR = ''

# Scratch-run override: while True, every run uses the 'Test_' prefix and the current
# directory, regardless of the per-dataset choice above. The resume path
# 'Test_MNIST__AE__8__0__9__end.pth' in the checkpoint-loading cell depends on this.
# Set to False to use the per-dataset naming.
USE_SCRATCH_RUN_NAMES = True
if USE_SCRATCH_RUN_NAMES:
    MODEL_NAME_PREF = 'Test_'
    SAVE_DIR = ''
###################


Setup runs


In [54]:
# some const part from run_train
# Model classes for this dataset/architecture pair; passed to init_model when building the model.
models_class_list = get_models_class_list(DATASET_TYPE, ARCHITECTURE_TYPE)

### Model parameters
# Start from the architecture defaults, then apply user overrides.
models_params = get_base_model_parameters(DATASET_TYPE, ARCHITECTURE_TYPE)
if N_LATENT != -1:  # -1 keeps the default bottleneck size
    models_params['BOTTLENECK'] = N_LATENT

BOTTLENECK = models_params['BOTTLENECK']
# (C, H, W) that the decoder's flat output is reshaped to before 2d upsampling.
C_H_W = models_params['C_H_W']

# Other model parameters: configure the model with the NONLINEARITY object itself,
# so the module-level constant and the model config refer to the same instance.
NONLINEARITY = nn.ReLU()
models_params['NONLINEARITY'] = NONLINEARITY

# LRAE parameters
# NOTE(review): 'gumbell' spelling kept as-is; confirm what the LRAE model code expects.
DROPOUT, TEMP, SAMPLING = 0.0, 0.5, 'gumbell'

models_params = models_params | {'N_BINS': N_BINS, 'DROPOUT': DROPOUT, 'SAMPLING': SAMPLING, 'TEMP': TEMP}
##


# -1 is a sentinel for "use the full split"; it is resolved once the dataset is loaded.
TRAIN_SIZE, TEST_SIZE = -1, -1

# Cadences, in epochs, used by the training loop.
EPOCH_SAVE = 25         # numbered checkpoint (PATH__{epoch}.pth), kept on disk
EPOCH_SAVE_BACKUP = 5   # rolling checkpoint (PATH__backup.pth), overwritten each time
SHOW_LOSS_BACKUP = 5    # redraw the loss-curve image (PATH_loss.jpg)




### setup main loss
# Reconstruction loss: BCE for VAE/LRAE, MSE for every other model type (AE, IRMAE).
# Compared case-insensitively, like the additional-loss selection and the parameter checks.
if MODEL_TYPE.upper() in ['VAE', 'LRAE']:
    main_loss = torch.nn.functional.binary_cross_entropy
    main_loss_str = 'torch.nn.functional.binary_cross_entropy'
else:  # AE, IRMAE
    main_loss = torch.nn.functional.mse_loss
    main_loss_str = 'torch.nn.functional.mse_loss'

print(f"'{main_loss_str}'\t will be used as main_loss")
##########################





### setup additional loss
# Extra term added to main_loss with weight ALPHA during training:
# kl_loss_from_ready for VAE, entropy_limit_loss for LRAE, none for the other types.
if MODEL_TYPE.upper() in ['VAE']:
    additional_loss, add_loss_str = kl_loss_from_ready, 'kl_loss_from_ready'
elif MODEL_TYPE.upper() in ['LRAE']:
    # Alternative regulariser: additional_loss, add_loss_str = wasser_loss, 'wasser_loss'
    additional_loss, add_loss_str = entropy_limit_loss, 'entropy_limit_loss'
else:  # AE, IRMAE
    additional_loss, add_loss_str = None, 'None'

print(f"'{add_loss_str}'\t will be used as additional_loss")
##########################




models were downloaded from 'models.NIPS_R1AE_MNIST'
'torch.nn.functional.mse_loss'	 will be used as main_loss
'None'	 will be used as additional_loss


In [55]:
def print_params(param_list, param_names_list):
    """Print one ``name: value`` line per parameter, followed by a blank line.

    Names and values are paired positionally; if the sequences differ in length,
    the surplus items of the longer one are ignored.
    """
    pairs = zip(param_names_list, param_list)
    for label, item in pairs:
        print("{}: {}".format(label, item))
    print()
    

# Show input data: echo every configuration group before training starts.
print('Input script data', '\n')
print('Main parameters:')
in_param_list = [SAVE_DIR, DEVICE, MODEL_TYPE, DATASET_TYPE, ARCHITECTURE_TYPE, BOTTLENECK, EPOCHS]
in_param__names_list = ['SAVE_DIR', 'DEVICE', 'MODEL_TYPE', 'DATASET_TYPE', 'ARCHITECTURE_TYPE', 'BOTTLENECK', 'EPOCHS']
print_params(in_param_list, in_param__names_list)
print()
print()


print('All model parameters:')
print_params(models_params.values(), models_params.keys())
print()

print('Training parameters:')
in_param_list = [BATCH_SIZE, LEARNING_RATE, ALPHA, EPOCHS]
in_param__names_list = ['BATCH_SIZE', 'LEARNING_RATE', 'ALPHA', 'EPOCHS']
print_params(in_param_list, in_param__names_list)
print()

print('Dataset parameters:')
in_param_list = [DATASET_TYPE, TRAIN_SIZE, TEST_SIZE, BATCH_SIZE]
in_param__names_list = ['DATASET_TYPE', 'TRAIN_SIZE', 'TEST_SIZE', 'BATCH_SIZE']
print_params(in_param_list, in_param__names_list)
print()


## other parameters
NUM_WORKERS = 32  # passed as num_workers to setup_dataset_training

# Other parameters
print('Other parameters')
other_param_list = [NUM_WORKERS, EPOCH_SAVE, EPOCH_SAVE_BACKUP, SHOW_LOSS_BACKUP]
other_param_names_list = ['NUM_WORKERS', 'EPOCH_SAVE', 'EPOCH_SAVE_BACKUP', 'SHOW_LOSS_BACKUP']
# Same "name: value" lines plus trailing blank line as the other groups.
print_params(other_param_list, other_param_names_list)

Input script data 

Main parameters:
SAVE_DIR: 
DEVICE: cpu
MODEL_TYPE: AE
DATASET_TYPE: MNIST
ARCHITECTURE_TYPE: NIPS
BOTTLENECK: 8
EPOCHS: 5



All model parameters:
IN_FEATURES: 1024
BOTTLENECK: 8
OUT_FEATURES: 8192
DS_IN_CHANNELS: 1
C_H_W: [128, 8, 8]
MIDDLE_MATRIXES: 8
NONLINEARITY: ReLU()
N_BINS: 20
DROPOUT: 0.0
SAMPLING: gumbell
TEMP: 0.5


Training parameters:
BATCH_SIZE: 256
LEARNING_RATE: 0.0001
ALPHA: 0
EPOCHS: 5


Dataset parameters:
DATASET_TYPE: MNIST
TRAIN_SIZE: -1
TEST_SIZE: -1
BATCH_SIZE: 256


Other parameters
NUM_WORKERS: 32
EPOCH_SAVE: 25
EPOCH_SAVE_BACKUP: 5
SHOW_LOSS_BACKUP: 5



In [56]:
# Checking parameters: each setting (upper-cased) must be one of the supported values.
for _value, _allowed, _kind in (
    (MODEL_TYPE, GOOD_MODEL_TYPE, "model type"),
    (DATASET_TYPE, GOOD_DATASET_TYPE, "dataset type"),
    (ARCHITECTURE_TYPE, GOOD_ARCHITECTURE_TYPE, "model architecture type"),
):
    assert _value.upper() in _allowed, f"Error, bad {_kind}, select from: {_allowed}"
del _value, _allowed, _kind
#############

## Dataset

In [57]:


####### Dataset
dataset_type = DATASET_TYPE
print('\n\n')
print(f'Loading dataset: {dataset_type}', flush=True)


train_ds, test_ds, ds_train_size, df_test_size, ds_in_channels = select_dataset(DATASET_TYPE, GOOD_DATASET_TYPE)
dl, dl_test = setup_dataset_training(train_ds, test_ds, BATCH_SIZE, num_workers=NUM_WORKERS)
# The model config (used by init_model) takes the channel count reported by the dataset.
models_params['DS_IN_CHANNELS'] = ds_in_channels


# Resolve the -1 ("full split") sentinel to the real split sizes.
# NOTE(review): neither select_dataset nor setup_dataset_training receives TRAIN_SIZE /
# TEST_SIZE, so values other than -1 are only reported here and do not shrink the loaders.
TRAIN_SIZE = ds_train_size if TRAIN_SIZE == -1 else TRAIN_SIZE
TEST_SIZE = df_test_size if TEST_SIZE == -1 else TEST_SIZE
print("Dataset parameters:")
print(f"TRAIN_SIZE: {TRAIN_SIZE}({ds_train_size})")
print(f"TEST_SIZE: {TEST_SIZE}({df_test_size})")
print(f"BATCH_SIZE = {BATCH_SIZE}")

print(f"{DATASET_TYPE} dataset logs:")
print("Img channel:", ds_in_channels)

print('\n\n')
###################




Loading dataset: MNIST
Dataset parameters:
TRAIN_SIZE: 60000(60000)
TEST_SIZE: 10000(10000)
BATCH_SIZE = 256
MNIST dataset logs:
Img channel: 1





## Initialization 

In [58]:

###################### Initialization of the model
device = DEVICE

# Run name: prefix + dataset + model type + bottleneck size + additional-loss weight.
model_name = f"{MODEL_NAME_PREF}{DATASET_TYPE}__{MODEL_TYPE}__{BOTTLENECK}__{ALPHA}"

print('\n\n')
print("Initialization of the model")
print("model_name: ", model_name, '\n\n')

model = init_model(MODEL_TYPE, GOOD_MODEL_TYPE, models_class_list, models_params, device)
print(model)
print(f"{MODEL_TYPE} was initialized")

# Stem for every checkpoint / loss-plot file written by the training cell.
PATH = os.path.join(SAVE_DIR, model_name)
print('Save PATH:', PATH, flush=True)
#################################





Initialization of the model
model_name:  Test_MNIST__AE__8__0 


ConvAE(
  (nonlinearity): ReLU()
  (down): Sequential(
    (0): Dropout(p=0, inplace=False)
    (1): DownsampleBlock(
      (conv): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (nonlinearity): ReLU()
    )
    (2): DownsampleBlock(
      (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (nonlinearity): ReLU()
    )
    (3): DownsampleBlock(
      (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (nonlinearity): ReLU()
    )
    (4): DownsampleBlock(
      (conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (nonlinearity): ReLU()
    )
  )
  (low_rank): InternalAutoencoder(
    (low_rank_pants): PantsAE(
      (out): Sequential(
        (0): Linear(in_features=1024, out_features=8, bias=True)
        (1): ReLU()
      )
    )
    (decoder): Sequential(
      (0): Linear(in_features=8, out_features=8192, b

## Training

In [59]:
####### Training
print("Training of the model", flush=True)
device = DEVICE

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Loss history (one entry per logging step of the training loop) and running sums.
loss_list_train = []
loss_train_cum = 0

loss_list_test = []
loss_test_cum = 0
i = 0        # global optimisation-step counter, drives the logging cadence
loss = 0
epoch_0 = 0  # first epoch index of this run; overwritten when resuming from a checkpoint

# Per-epoch wall-clock durations (seconds).
epoch_time_list = []
epoch_t1 = None

alpha = ALPHA  # weight of additional_loss in the training objective

epoch_save_backup = EPOCH_SAVE_BACKUP
epoch_save = EPOCH_SAVE
show_loss_backup = SHOW_LOSS_BACKUP


Training of the model


In [60]:

# Set load_path = None to train from scratch; otherwise training resumes from this checkpoint.
# NOTE(review): on a fresh kernel/machine this file must exist in the working directory.
load_path = 'Test_MNIST__AE__8__0__9__end.pth'
# loading checkpoint
if load_path is not None:
    print('The training will be continue!!!')
    checkpoint = upload_checkpoint_in(load_path, model=model, optimizer=optimizer, device=device)
    print("Loaded epoch:", checkpoint['epoch'])
    print("Loaded final loss:", checkpoint['loss'])

    loss_list_train = checkpoint['loss_list_train']
    loss_list_test = checkpoint['loss_list_test']
    # Checkpoints without timing info fall back to an empty list (never None): the
    # training loop appends to this list and averages it at the end.
    epoch_time_list = checkpoint.get('epoch_time_list') or []

    # Continue with the epoch after the one stored in the checkpoint.
    epoch_0 = checkpoint['epoch'] + 1

    del checkpoint


The training will be continue!!!
Loaded epoch: 9
Loaded final loss: tensor(0.0498, device='cuda:2', requires_grad=True)


In [63]:
#### additional
# Reference point for the "Elapsed time" report printed after training finishes.
timer_start = timer()


In [64]:

# Training loop.
# Runs EPOCHS epochs starting at epoch_0 (non-zero when resuming), logs mean train/test loss
# every LOG_EVERY optimisation steps, writes numbered and rolling checkpoints, redraws the
# loss curve, then saves a final "__end" checkpoint.
# Uses state from earlier cells: model, optimizer, dl, dl_test, main_loss, additional_loss,
# alpha, C_H_W, PATH, device, epoch_0, the loss lists, the counters and timer_start.

LOG_EVERY = 100  # optimisation steps between train/test loss logging points


def _forward_with_factors(model, x, c_h_w):
    """Run the model's sub-modules step by step; return (reconstruction, factors_probability).

    The second output of the low-rank bottleneck is needed by additional_loss.
    NOTE(review): presumably equivalent to model(x), which is used for evaluation; confirm.
    """
    x_down = model.down(x)  # 2d downsampling
    batch, channels, height, width = x_down.shape
    x_flat = x_down.view(batch, channels * height * width)

    encoded, factors_probability = model.low_rank.low_rank_pants(x_flat)
    decoded_1d = model.low_rank.decoder(encoded)

    # Reshape the flat decoder output to the configured (C, H, W) before 2d upsampling.
    c, h, w = c_h_w
    decoded_2d = model.up(decoded_1d.view(batch, c, h, w))
    return decoded_2d, factors_probability


def _mean_test_loss(model, data_loader, loss_fn, device):
    """Return the mean per-batch ``loss_fn(model(x), x)`` over ``data_loader``.

    Evaluates in eval mode under ``torch.no_grad()`` and always switches the model back
    to train mode afterwards. Raises AssertionError if any batch output contains NaNs.
    """
    model.eval()
    total = 0.0
    try:
        with torch.no_grad():
            for x_test, _ in data_loader:
                x_test = x_test.to(device)
                x_decoded = model(x_test)
                nan_count = torch.isnan(x_decoded).sum()
                assert nan_count == 0, f"Error! Nan values ({nan_count}) in models output"
                total += loss_fn(x_decoded, x_test).item()
    finally:
        # Training resumes right after evaluation, so train mode must be restored.
        model.train()
    return total / len(data_loader)


# Checkpoints saved without timing info may have produced None here.
if epoch_time_list is None:
    epoch_time_list = []

model.train()
torch.cuda.empty_cache()

last_epoch = epoch_0 + EPOCHS - 1
for epoch in tqdm(range(epoch_0, epoch_0 + EPOCHS)):
    epoch_start = timer()
    model.train()

    for x_batch, _ in dl:
        x_batch = x_batch.to(device)
        decoded_2d, factors_probability = _forward_with_factors(model, x_batch, C_H_W)

        # Reconstruction loss, plus the weighted regulariser if one is configured.
        loss = main_loss(decoded_2d, x_batch)
        if additional_loss is not None:
            # Out-of-place add: avoids an in-place update of a tensor that autograd tracks.
            loss = loss + alpha * additional_loss(factors_probability)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_train_cum += loss.item()

        # Periodic logging of mean train loss and full test-set loss.
        i += 1
        if i % LOG_EVERY == 0:
            loss_list_train.append(loss_train_cum / LOG_EVERY)
            loss_train_cum = 0
            loss_list_test.append(_mean_test_loss(model, dl_test, main_loss, device))

    # Numbered checkpoint, kept on disk.
    if epoch % epoch_save == 0:
        save_checkpoint_from(PATH + f"__{epoch}.pth", model, optimizer, epoch=epoch, loss=loss.item(),
                             loss_list_train=loss_list_train, loss_list_test=loss_list_test,
                             epoch_time_list=epoch_time_list)

    # Rolling backup checkpoint, overwritten each time.
    if epoch % epoch_save_backup == 0:
        save_checkpoint_from(PATH + "__backup.pth", model, optimizer, epoch=epoch, loss=loss.item(),
                             loss_list_train=loss_list_train, loss_list_test=loss_list_test,
                             epoch_time_list=epoch_time_list)

    # Redraw the loss curve periodically and always after the last epoch of this run
    # (epoch indices are offset by epoch_0 when resuming).
    if (epoch % show_loss_backup == show_loss_backup - 1) or (epoch == last_epoch):
        fig, ax = plt.subplots(figsize=(6, 3))
        ax.plot(loss_list_train, alpha=0.5, label='train')
        ax.plot(loss_list_test, alpha=0.5, label='test')
        ax.set(xlabel=f"logging step (every {LOG_EVERY} iterations)", ylabel="loss")
        ax.legend()
        fig.savefig(PATH + "_loss.jpg")
        plt.close(fig)

    # Full duration of this epoch (training, evaluation, saving, plotting). Recorded after
    # the saves, so the checkpoint of epoch k holds the durations of the epochs before k.
    epoch_time_list.append(timer() - epoch_start)


print("Finishing of the training...")

if EPOCHS > 0:
    # `epoch` is the last epoch index of this run (epoch_0 + EPOCHS - 1).
    save_checkpoint_from(PATH + f"__{epoch}__end.pth", model, optimizer, epoch=epoch, loss=loss.item(),
                         loss_list_train=loss_list_train, loss_list_test=loss_list_test,
                         epoch_time_list=epoch_time_list)

#######################

timer_end = timer()
print(f"Elapsed time: {timer_end - timer_start:.2f} second")  # wall-clock seconds since timer_start
mean_epoch_time = np.mean(epoch_time_list) if epoch_time_list else float('nan')
print("Mean full epoch time:", f"{mean_epoch_time:.1f}")
print(f"Model training for {model_name} was successfully finished and saved!")
print('\n\n\n\n\n')


 40%|████      | 2/5 [11:38<17:28, 349.44s/it]


KeyboardInterrupt: 

In [21]:
# NOTE(review): leftover debugging cell. It displays the current value of `epoch` left by the
# training loop (e.g. after a KeyboardInterrupt). Its execution count (21) is older than the
# training cell's (64), so the saved output is stale; consider deleting this cell.
epoch

8