In [1]:
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
# Use the first CUDA device when available (pick another GPU with "cuda:1",
# "cuda:2", ...); otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on the GPU" if device.type == "cuda" else "Running on the CPU")

Running on the GPU


### Load Data

In [3]:
# One preprocessing pipeline shared by both splits. Normalize takes explicit
# single-channel 1-tuples; `(0.5)` without the comma is just a float.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

trainset = datasets.MNIST('', download=True, train=True, transform=mnist_transform)
train_loader = DataLoader(trainset, batch_size=50, shuffle=True)

testset = datasets.MNIST('', download=True, train=False, transform=mnist_transform)
test_loader = DataLoader(testset, batch_size=50, shuffle=False)

# Train/validation split over trainset.data only (labels are not carried along);
# random_state pins the partition so reruns give the same split.
train_mnist, val_mnist = train_test_split(trainset.data, test_size=0.2, random_state=42)

### Split Images into Domains and Mask Missing Pixels

In [4]:
def make_mnist_domains(x, n_domains=4, n_missing_samples=100, missing_idxs=None):
    """Split flattened images into equal-sized domains and zero out pixel columns.

    Parameters
    ----------
    x : torch.Tensor
        Image stack whose dims 1 and 2 are flattened into one pixel axis
        (e.g. (N, 28, 28) -> (N, 784)). N must be divisible by ``n_domains``.
        ``x`` itself is never modified.
    n_domains : int
        Number of equal row chunks ("domains") to produce.
    n_missing_samples : int
        Number of distinct pixel columns to zero per domain when
        ``missing_idxs`` is not supplied (must not exceed the pixel count).
    missing_idxs : list of index arrays, optional
        One array of pixel indices per domain. Pass the value returned by an
        earlier call to reuse the same missing pattern (train -> val/test).

    Returns
    -------
    splits : list of torch.Tensor
        ``n_domains`` chunks of the flattened images with missing pixels set to 0.
    masks : list of torch.Tensor
        Same shapes as ``splits``: 1 where a pixel is kept, 0 where it was removed.
    missing_idxs : list
        The per-domain pixel indices that were zeroed.

    Raises
    ------
    ValueError
        If the number of rows is not divisible by ``n_domains``.
    """
    # flatten (and row slicing) return views; clone so the zeroing below does not
    # write through into the caller's tensor (e.g. testset.data).
    flat = torch.flatten(x, 1, 2).clone()
    if flat.shape[0] % n_domains:
        raise ValueError(
            f"cannot split {flat.shape[0]} rows into {n_domains} equal domains")
    splits = list(torch.tensor_split(flat, n_domains))
    masks = [torch.ones_like(split) for split in splits]

    if missing_idxs is None:
        # replace=False: each domain loses exactly n_missing_samples distinct pixels.
        missing_idxs = [
            np.random.choice(split.shape[1], n_missing_samples, replace=False)
            for split in splits
        ]

    for idxs, split, mask in zip(missing_idxs, splits, masks):
        split[:, idxs] = 0
        mask[:, idxs] = 0

    return splits, masks, missing_idxs
    

### Get Splits and Masks

In [5]:
# Draw the missing-pixel pattern once on the training data, then feed the same
# indices to validation and test so every partition drops identical pixels.
train_splits, train_masks, idxs = make_mnist_domains(x=train_mnist)
val_splits, val_masks, idxs = make_mnist_domains(x=val_mnist, missing_idxs=idxs)
test_splits, test_masks, idxs = make_mnist_domains(x=testset.data, missing_idxs=idxs)

In [13]:
train_splits[0].size()

torch.Size([12000, 784])

### Define the Model

In [6]:
class domain_style_encoders(nn.Module):
    """Per-domain residual encoders applied to a domain-stacked batch.

    The input holds ``n_domains`` consecutive row blocks of ``batch_size`` rows;
    block ``i`` goes through encoder ``i`` and is added back to itself.
    """

    def __init__(self, input_shape, n_domains, batch_size, device='cpu'):
        super(domain_style_encoders, self).__init__()

        self.input_shape = input_shape
        self.n_domains = n_domains
        self.batch_size = batch_size
        self.Encoder_List = nn.ModuleList().to(device)
        self.device = device

        for i in range(n_domains):
            self.Encoder_List.append(
                nn.Sequential(
                    nn.Linear(input_shape, input_shape),
                    nn.ELU(),
                    nn.Linear(input_shape, input_shape),
                    nn.ELU()
                ).to(device)
            )

    def forward(self, x):
        """Return ``block_i + encoder_i(block_i)`` for every domain block, stacked.

        x : tensor with at least ``n_domains * batch_size`` rows of width
            ``input_shape``; rows past that are ignored.
        Returns a tensor of shape ``(n_domains * batch_size, input_shape)``.
        Raises ValueError if ``x`` has too few rows (a short final block would
        otherwise be silently broadcast or mis-sized).
        """
        n_rows = self.n_domains * self.batch_size
        if x.shape[0] < n_rows:
            raise ValueError(
                f"expected at least {n_rows} rows (n_domains * batch_size), got {x.shape[0]}")

        # Concatenate per-domain outputs rather than filling an uninitialized buffer.
        blocks = []
        for i, encoder in enumerate(self.Encoder_List):
            block = x[i * self.batch_size:(i + 1) * self.batch_size]
            blocks.append(block + encoder(block))
        return torch.cat(blocks, dim=0)

In [7]:
class domain_style_decoders(nn.Module):
    """Per-domain residual decoders applied to a domain-stacked batch.

    The input holds ``n_domains`` consecutive row blocks of ``batch_size`` rows;
    block ``i`` goes through decoder ``i`` and is added back to itself.
    """

    def __init__(self, input_shape, n_domains, batch_size, device='cpu'):
        # super() must name this class; naming another class raises TypeError.
        super(domain_style_decoders, self).__init__()

        self.input_shape = input_shape
        self.n_domains = n_domains
        self.batch_size = batch_size
        self.Decoder_List = nn.ModuleList().to(device)
        self.device = device

        for i in range(n_domains):
            self.Decoder_List.append(
                nn.Sequential(
                    nn.Linear(input_shape, input_shape),
                    nn.ELU(),
                    nn.Linear(input_shape, input_shape),
                    nn.ELU()
                ).to(device)
            )

    def forward(self, x):
        """Return ``block_i + decoder_i(block_i)`` for every domain block, stacked.

        x : tensor with at least ``n_domains * batch_size`` rows of width
            ``input_shape``; rows past that are ignored.
        Returns a tensor of shape ``(n_domains * batch_size, input_shape)``.
        Raises ValueError if ``x`` has too few rows.
        """
        n_rows = self.n_domains * self.batch_size
        if x.shape[0] < n_rows:
            raise ValueError(
                f"expected at least {n_rows} rows (n_domains * batch_size), got {x.shape[0]}")

        # Concatenate per-domain outputs rather than filling an uninitialized buffer.
        blocks = []
        for i, decoder in enumerate(self.Decoder_List):
            block = x[i * self.batch_size:(i + 1) * self.batch_size]
            blocks.append(block + decoder(block))
        return torch.cat(blocks, dim=0)

In [8]:
class domain_style_adversaries(nn.Module):
    """Per-domain binary discriminators applied to a domain-stacked batch.

    Block ``i`` (``batch_size`` consecutive rows) is scored by adversary ``i``;
    each adversary emits one sigmoid probability per row.
    """

    def __init__(self, input_shape, n_domains, batch_size, device='cpu'):
        super(domain_style_adversaries, self).__init__()

        self.input_shape = input_shape
        self.n_domains = n_domains
        self.batch_size = batch_size
        self.device = device

        self.Adversary_List = nn.ModuleList().to(device)

        for i in range(n_domains):
            self.Adversary_List.append(
                nn.Sequential(
                    # TODO: consider adding a gradient reversal layer here.
                    nn.Linear(input_shape, 50),
                    nn.ELU(),
                    nn.Linear(50, 1),
                    nn.Sigmoid()
                ).to(device)
            )

    def forward(self, x):
        """Score every domain block with its own adversary.

        x : tensor with at least ``n_domains * batch_size`` rows of width
            ``input_shape``; rows past that are ignored.
        Returns a tensor of shape ``(n_domains * batch_size, 1)`` with values in (0, 1).
        Raises ValueError if ``x`` has too few rows (a one-row final block would
        otherwise be silently broadcast across the whole block).
        """
        n_rows = self.n_domains * self.batch_size
        if x.shape[0] < n_rows:
            raise ValueError(
                f"expected at least {n_rows} rows (n_domains * batch_size), got {x.shape[0]}")

        scores = [
            adversary(x[i * self.batch_size:(i + 1) * self.batch_size])
            for i, adversary in enumerate(self.Adversary_List)
        ]
        return torch.cat(scores, dim=0)

In [None]:
class latent_domain_adversary(nn.Module):
    """Classifier that predicts which domain a latent code came from.

    Maps ``n_components``-dim latent vectors to a probability distribution over
    ``n_domains`` domains.
    """

    def __init__(self, n_components, n_domains):
        super(latent_domain_adversary, self).__init__()

        self.n_components = n_components
        self.n_domains = n_domains

        self.Latent_Domain_Adversary = nn.Sequential(
            nn.Linear(n_components, n_domains * 2),
            nn.ELU(),
            nn.Linear(n_domains * 2, n_domains),
            # Explicit dim: implicit-dim Softmax is deprecated and picks the axis
            # from the input rank. dim=-1 normalizes over the domain axis.
            # NOTE(review): if this is trained with nn.CrossEntropyLoss, drop the
            # Softmax -- that loss expects raw logits.
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """Return per-domain probabilities for latent codes ``x`` (last dim = n_components)."""
        return self.Latent_Domain_Adversary(x)

In [10]:
class shared_encoder_decoder(nn.Module):
    """Variational-style autoencoder shared across all domains.

    Encodes ``input_shape``-dim vectors into an ``n_components``-dim latent code
    via the reparameterization trick and decodes back to ``input_shape`` dims.
    """

    def __init__(self, input_shape, n_components):
        super(shared_encoder_decoder, self).__init__()

        self.se_hidden_1 = 100
        self.sd_hidden_1 = 100
        self.n_components = n_components

        self.shared_encoder = nn.Sequential(
            nn.Linear(input_shape, self.se_hidden_1),
            nn.ELU(),
            nn.Linear(self.se_hidden_1, self.n_components),
            nn.ELU()
        )

        self.se_mean = nn.Linear(self.n_components, self.n_components)
        self.se_logvar = nn.Linear(self.n_components, self.n_components)

        self.shared_decoder = nn.Sequential(
            nn.Linear(n_components, self.sd_hidden_1),
            nn.ELU(),
            nn.Linear(self.sd_hidden_1, input_shape),
            nn.ELU()
        )

    def Sampling(self, mean, log_var):
        """Draw ``z = mean + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``."""
        # randn_like puts eps on log_var's device and dtype, so the model works on
        # CPU and on any GPU instead of only on 'cuda:0'.
        eps = torch.randn_like(log_var)
        return mean + torch.exp(log_var / 2) * eps

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns ``((z, mean, logvar), x_recon)``.
        """
        h = self.shared_encoder(x)
        mean = self.se_mean(h)
        logvar = self.se_logvar(h)
        z = self.Sampling(mean, logvar)
        x_recon = self.shared_decoder(z)
        return (z, mean, logvar), x_recon
    

### Training

In [11]:
# Shared autoencoder over flattened 784-pixel images with a 25-dim latent space.
model = shared_encoder_decoder(input_shape=784, n_components=25)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
lossFunction = nn.MSELoss()
num_epochs = 5

In [12]:
# Reconstruction-only objective: the (z, mean, logvar) tuple the model returns
# is not used, so no KL term enters the loss.
for epoch in range(num_epochs):
    loss_ = 0
    for images, labels in train_loader:
        # Flatten each 28x28 image to a 784-vector for the fully-connected model.
        images = images.reshape(-1, 784).to(device)
        dist_tuple, image_recon = model(images)
        loss = lossFunction(image_recon, images)

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_ += loss.item()

    mean_loss = loss_ / len(train_loader)
    print("Epoch{}, Training loss:{}".format(epoch, mean_loss))

Epoch0, Training loss:0.14060185118888816
Epoch1, Training loss:0.08534790980940064
Epoch2, Training loss:0.07328506947805484
Epoch3, Training loss:0.06496107386114697
Epoch4, Training loss:0.060066064186394215


In [None]:
idx = 10  # which sample of the last training batch to inspect
# image_recon / images are the final batch left over from the training loop.
# Both panels share the [-1, 1] range produced by Normalize(0.5, 0.5), so
# intensities are directly comparable (independent autoscaling would hide differences).
fig, (ax_recon, ax_orig) = plt.subplots(1, 2, figsize=(6, 3))
ax_recon.imshow(image_recon[idx].detach().cpu().numpy().reshape(28, 28),
                cmap="gray", vmin=-1, vmax=1)
ax_recon.set_title("Reconstruction")
ax_orig.imshow(images[idx].detach().cpu().numpy().reshape(28, 28),
               cmap="gray", vmin=-1, vmax=1)
ax_orig.set_title("Original")
for ax in (ax_recon, ax_orig):
    ax.axis("off")
plt.show()

In [151]:
np.square(
    image_recon[idx].detach().cpu().numpy().reshape(28, 28)
    - images[idx].detach().cpu().numpy().reshape(28, 28)
).mean()

0.011492738

In [145]:
images[idx].detach().cpu().numpy().reshape(28, 28).mean()

0.11883253

In [146]:
image_recon[idx].detach().cpu().numpy().reshape(28, 28).mean()

0.112951905