In [11]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torch_data
import os
from datetime import datetime
import time
from IPython import display
from tqdm import tqdm

%matplotlib inline

In [12]:
# Seed PyTorch's random number generators so repeated runs start from the same state.
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
# Make float32 the default floating-point dtype. torch.set_default_dtype is the
# supported replacement for the deprecated torch.set_default_tensor_type.
torch.set_default_dtype(torch.float32)
# Last expression of the cell: displays whether a CUDA device is visible.
torch.cuda.is_available()

True

In [13]:
class Data(torch.utils.data.Dataset):
    """In-memory dataset that pairs samples with integer labels.

    Both inputs are copied into tensors once, at construction time:
    ``X`` as float32 and ``y`` as int64. Indexing returns the matching
    ``(sample, label)`` pair.
    """

    def __init__(self, X, y):
        super().__init__()
        # Convert up front so __getitem__ is a plain tensor lookup.
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.int64)

    def __len__(self):
        # The number of samples is the size of the leading axis of X.
        return self.X.size(0)

    def __getitem__(self, idx):
        sample = self.X[idx]
        label = self.y[idx]
        return sample, label

In [14]:
# Placeholder data for a smoke test: constant volumes shaped
# (samples, channels, 256, 256, 256) plus one integer domain label per sample.
# The arrays are allocated as float32 straight away: Data stores float32
# tensors, so NumPy's default float64 would only double the host memory of the
# source arrays before the cast.
X_train = np.ones((2, 1, 256, 256, 256), dtype=np.float32)
y_train = np.array([1, 12])
X_val = np.ones((1, 1, 256, 256, 256), dtype=np.float32)
y_val = np.array([12])
train_dataset = Data(X_train, y_train)
val_dataset = Data(X_val, y_val)

batch_size = 2

# Shuffle only the training set; the validation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


In [15]:
from AE_model import AE, Discriminator


In [16]:
def plot_loss(train_loss,train_loss_disc, val_loss, val_loss_disc):
    """Redraw a 2x2 grid of loss curves, replacing the cell's previous output.

    Each argument is a sequence of loss values. A panel is drawn (with a
    logarithmic y-axis, a title and axis labels) only when its sequence is
    non-empty; otherwise the corresponding axes are left blank.

    Layout: top row = training (main loss, discriminator loss),
    bottom row = validation (main loss, discriminator loss).
    """
    display.clear_output(wait=True)
    fig, ax = plt.subplots(2, 2, figsize=(10, 10))

    # (grid position, values, panel title) in drawing order
    panels = (
        ((0, 0), train_loss, 'Training loss'),
        ((0, 1), train_loss_disc, 'Training discriminator loss'),
        ((1, 0), val_loss, 'Validation  loss'),
        ((1, 1), val_loss_disc, 'Validation discriminator loss'),
    )
    for (row, col), history, title in panels:
        if not history:
            continue
        panel = ax[row, col]
        panel.semilogy(history)
        panel.set_title(title)
        panel.set_xlabel('# batches processed')
        panel.set_ylabel('loss value')

    plt.show()

In [17]:
def adv_loss(y, pred_logits):
    """Negated mean log-probability that the discriminator assigns to the wrong domains.

    Computes ``-mean((1 - onehot(y)) * log_softmax(pred_logits, dim=1))``, where
    the mean runs over every (sample, domain) entry. The entry of the true
    domain is masked out in each row, so the value shrinks as the predicted
    probability of the other domains grows.

    The number of domains and the target device are taken from ``pred_logits``
    itself, so the function does not rely on module-level globals.

    Args:
        y: integer (int64) tensor with one true domain index per sample; it is
            moved to the device of ``pred_logits`` if necessary.
        pred_logits: float tensor of shape (batch, n_domains) holding the raw
            discriminator outputs.

    Returns:
        Scalar tensor, differentiable with respect to ``pred_logits``.
    """
    log_prob = F.log_softmax(pred_logits, dim=1)
    # Mask that is 0 at the true domain and 1 everywhere else, i.e. the
    # complement of the one-hot label. It is created with the dtype and device
    # of the log-probabilities so the product needs no type promotion.
    wrong_domain = torch.ones_like(log_prob)
    true_idx = y.view(-1, 1).to(log_prob.device)
    wrong_domain.scatter_(1, true_idx, 0.0)
    return -torch.mean(wrong_domain * log_prob)

def main_loss(X_rec, X, rec_criterion_ae,
              y, pred_logits, lambda_t):
    """Autoencoder objective: reconstruction loss plus a weighted adversarial term.

    Args:
        X_rec: reconstruction produced by the autoencoder.
        X: ground-truth input that ``X_rec`` is compared against.
        rec_criterion_ae: callable ``(X_rec, X) -> loss`` used for the
            reconstruction term (this notebook passes ``nn.MSELoss()``).
        y: true domain labels, forwarded to ``adv_loss``.
        pred_logits: discriminator logits, forwarded to ``adv_loss``.
        lambda_t: weight of the adversarial term.

    Returns:
        ``rec_criterion_ae(X_rec, X) + lambda_t * adv_loss(y, pred_logits)``.
    """
    # Use the criterion that was passed in; the previous body ignored this
    # parameter and silently picked up the module-level ``criterion_ae``.
    loss_rec = rec_criterion_ae(X_rec, X)
    loss_adversarial = adv_loss(y, pred_logits)
    return loss_rec + lambda_t * loss_adversarial

In [18]:
def _set_requires_grad(module, flag):
    """Switch gradient tracking on or off for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = flag


def train(ae, disc, criterion_ae, criterion_disc, optimizer_ae, optimizer_disc,
        train_loader, val_loader,epochs, disc_loop=1,scheduler=None, exp_name=None,   save_dir=None):
    """Alternately train the discriminator and the autoencoder, validating after each epoch.

    For every training batch:
      1. the autoencoder is frozen and the discriminator takes ``disc_loop``
         optimisation steps on ``criterion_disc(disc(ae.enc(X)), y)``;
      2. the discriminator is frozen and the autoencoder takes one step on
         ``main_loss`` (reconstruction + ``lambda_t`` * adversarial term);
      3. ``lambda_t`` is moved one step along a linear ramp from 0 to 1e-4
         that spans 500000 batches and stays at 1e-4 afterwards.
    Every 10th batch the current losses are recorded and plotted.

    After the training batches of an epoch, both models are evaluated on
    ``val_loader`` under ``torch.no_grad()``; one loss pair is recorded per
    validation batch and the curves are re-plotted.

    Tensors are moved to ``device``, which is read from module scope and must
    be defined before this function is called.

    Args:
        ae: autoencoder; must be callable and expose ``enc``, ``parameters()``,
            ``train()`` and ``eval()``.
        disc: discriminator applied to ``ae.enc(X)``.
        criterion_ae: reconstruction criterion forwarded to ``main_loss``.
        criterion_disc: classification criterion for the discriminator.
        optimizer_ae: optimiser for the autoencoder.
        optimizer_disc: optimiser for the discriminator.
        train_loader: iterable of ``(X, y)`` training batches with ``len()``.
        val_loader: iterable of ``(X, y)`` validation batches with ``len()``.
        epochs: number of passes over ``train_loader``.
        disc_loop: discriminator steps per batch; must be at least 1.
        scheduler: currently unused.
        exp_name: currently unused (only referenced by the disabled
            TensorBoard lines).
        save_dir: currently unused.

    Returns:
        dict with the recorded loss histories under the keys ``'train_loss'``,
        ``'train_loss_disc'``, ``'val_loss'`` and ``'val_loss_disc'``.

    Raises:
        ValueError: if ``disc_loop`` is smaller than 1.
    """
    if disc_loop < 1:
        # Without a discriminator step there is no loss_disc to record; fail
        # early instead of hitting an unbound name in the middle of an epoch.
        raise ValueError(f'disc_loop must be >= 1, got {disc_loop}')

    lambda_final = 1e-4
    lambda_initial = 0
    max_step = 500000
    step = 0
    lambda_t = lambda_initial

#     writer = SummaryWriter(f'logs/{exp_name}')
    train_loss = []
    train_loss_disc = []
    val_loss = []
    val_loss_disc = []

    for epoch in range(epochs):
        for batch_no, (X, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
            # X - ground truth image
            # y - domain label, a single scalar per sample
            X = X.to(device)
            y_dev = y.to(device)

            ### train discriminator (autoencoder parameters frozen)
            ae.eval()
            _set_requires_grad(ae, False)
            disc.train()
            for _ in range(disc_loop):
                optimizer_disc.zero_grad()
                pred_logits = disc(ae.enc(X))
                # Cross-entropy against the real domain: the discriminator
                # learns to predict where a sample came from.
                loss_disc = criterion_disc(pred_logits, y_dev)
                loss_disc.backward()
                optimizer_disc.step()
            _set_requires_grad(ae, True)
            del pred_logits  # release the reference before the autoencoder step

            ### train autoencoder (discriminator parameters frozen)
            ae.train()
            disc.eval()
            _set_requires_grad(disc, False)

            optimizer_ae.zero_grad()
            X_rec = ae(X)
            pred_logits = disc(ae.enc(X))
            # `y` is passed as loaded (not moved to `device`); adv_loss places
            # its label mask on the right device itself.
            loss = main_loss(X_rec, X, criterion_ae, y, pred_logits, lambda_t)
            loss.backward()
            optimizer_ae.step()
            _set_requires_grad(disc, True)
            del X_rec, pred_logits  # drop large tensors before the next batch

            ### increase lambda
            # Closed-form ramp: reaches lambda_final exactly at max_step and
            # avoids the rounding drift of repeated `+=` accumulation.
            step += 1
            lambda_t = lambda_initial + (lambda_final - lambda_initial) * (min(step, max_step) / max_step)

            ### plot
            if batch_no % 10 == 0:
                train_loss.append(loss.item())
                train_loss_disc.append(loss_disc.item())
                plot_loss(train_loss, train_loss_disc, val_loss, val_loss_disc)
                print(f'epoch {epoch} training stage...')
#                 writer.add_scalar('train loss', loss.item(), global_step=len(train_loss))
#                 writer.add_scalar('train clas loss', loss_disc.item(), global_step=len(train_loss))

        print(f'epoch {epoch} testing stage...')
        ae.eval()
        disc.eval()
        # Pre-bind so the cleanup below also works when val_loader is empty.
        X_rec = pred_logits = None
        with torch.no_grad():
            for X, y in tqdm(val_loader, total=len(val_loader)):
                X = X.to(device)
                pred_logits = disc(ae.enc(X))
                loss_disc = criterion_disc(pred_logits, y.to(device))
                X_rec = ae(X)
                loss = main_loss(X_rec, X, criterion_ae, y, pred_logits, lambda_t)
                val_loss.append(loss.item())
                val_loss_disc.append(loss_disc.item())
#                 writer.add_scalar('val loss', loss.item(), global_step=len(val_loss))
#                 writer.add_scalar('val clas loss', loss_disc.item(), global_step=len(val_loss))
        del X_rec, pred_logits

        # Refresh the figure so this epoch's validation losses are visible
        # (otherwise the last epoch's values would never be drawn).
        plot_loss(train_loss, train_loss_disc, val_loss, val_loss_disc)

    return {
        'train_loss': train_loss,
        'train_loss_disc': train_loss_disc,
        'val_loss': val_loss,
        'val_loss_disc': val_loss_disc,
    }
    
                            

In [21]:
# Number of domains the discriminator distinguishes. Defined once here and
# reused in discriminator_kwargs below.
n_domains = 20

# Settings for one down-sampling block.
down_block_kwargs = {
    'conv_k': 3,
    'conv_pad': 1,
    'conv_s': 1,
    'maxpool_k': 2,
    'maxpool_s': 2,
    'batch_norm': True,
    'act': 'relu',  # or 'l_relu'
}

# Settings for one up-sampling block.
up_block_kwargs = {
    'up': 'upsample',  # or 'transpose_conv'
    'scale': 2,
    'scale_mode': 'nearest',
    'conv_k': 3,
    'conv_pad': 1,
    'conv_s': 1,
    'batch_norm': True,
    'act': 'relu',  # or 'l_relu'
}

# Keyword arguments for AE; the two block configurations above are nested in.
ae_kwargs = {
    'c_in': 1,
    'is_skip': False,
    'deapth': 7,  # key spelling kept as-is; presumably what AE expects - verify against AE_model
    'c_base': 8,
    'inc_size': 2,
    'reduce_size': False,
    'down_block_kwargs': down_block_kwargs,
    'up_block_kwargs': up_block_kwargs,
}

# Keyword arguments for Discriminator.
discriminator_kwargs = {
    'c_in': 512,
    'c_out': 1024,
    'conv_k': 2,
    'conv_s': 2,
    'conv_pad': 0,
    'l_in': 1024,
    'l_out': 512,
    'batch_norm': True,
    'act': 'l_relu',
    'n_domains': n_domains,
}

ae = AE(**ae_kwargs)
disc = Discriminator(**discriminator_kwargs)

In [20]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

ae.to(device)
disc.to(device)

# Adam hyper-parameters, kept separately for the two networks.
learning_rate_ae = 0.001
betas_ae = (0.9, 0.999)
learning_rate_disc = 0.001
betas_disc = (0.9, 0.999)

# Reconstruction loss for the autoencoder, classification loss for the discriminator.
criterion_ae = nn.MSELoss()
criterion_disc = nn.CrossEntropyLoss()

optimizer_ae = torch.optim.Adam(
    ae.parameters(), lr=learning_rate_ae, betas=betas_ae)
optimizer_disc = torch.optim.Adam(
    disc.parameters(), lr=learning_rate_disc, betas=betas_disc)

train(
    ae, disc, criterion_ae, criterion_disc, optimizer_ae, optimizer_disc,
    train_loader, val_loader, epochs=2, disc_loop=1,
)

  0%|                                                                                            | 0/1 [00:00<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 290.00 MiB (GPU 0; 4.00 GiB total capacity; 2.66 GiB already allocated; 149.35 MiB free; 119.53 MiB cached)

In [35]:
# c = 32
# n_domains = 20
# enc = Encoder(c = c)
# dec = Decoder(c = c)
# disc = Discriminator(n_domains = n_domains)
# device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [36]:
# enc.to(device)
# summary(enc, (1,256,256,256),device=device)

In [37]:
# dec.to(device)
# summary(dec, (256,4,4,4),device=device)

In [38]:
# disc.to(device)
# summary(disc, (256,4, 4, 4),device=device)