In [1]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm import tqdm
%matplotlib inline
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision

import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from tqdm import tqdm as tqdm
# Select the CUDA device when PyTorch reports one as available, otherwise use the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [2]:
class Focus(nn.Module):
    """Bias-free linear scorer that maps each input vector to one raw score.

    Args:
        input_dims: size of the last dimension of tensors given to ``forward``.
    """

    def __init__(self, input_dims):
        super().__init__()
        self.input_dims = input_dims
        # One linear layer without a bias term: the score is just weight . x
        self.fc1 = nn.Linear(in_features=input_dims, out_features=1, bias=False)

    def forward(self, x):
        """Return ``fc1(x)`` unchanged (no activation); the last dimension becomes 1."""
        return self.fc1(x)
        

In [3]:
class Classification(nn.Module):
    """Single linear layer followed by a probability-producing activation.

    Args:
        input_dims: size of the last dimension of tensors given to ``forward``.
        output_dims: number of outputs of the linear layer. With more than one
            output a softmax over ``dim=1`` is applied; otherwise a sigmoid.
    """

    def __init__(self, input_dims, output_dims):
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.fc1 = nn.Linear(in_features=input_dims, out_features=output_dims)

    def forward(self, x):
        """Return probabilities (not logits) computed from ``fc1(x)``."""
        logits = self.fc1(x)
        # Single-output case: element-wise sigmoid.
        if self.output_dims <= 1:
            return torch.sigmoid(logits)
        # Multi-output case: normalise across dim 1.
        return F.softmax(logits, dim=1)

In [4]:
def expectation_step(fc,cl,data):
    """Compute normalised weights over dim 1 of ``data`` from the two modules.

    For every entry the weight is ``softmax(fc(data), dim=1) * (1 - cl(data))``,
    divided by the sum of those products along dim 1. Both modules are
    evaluated under ``torch.no_grad()``.

    Args:
        fc: callable applied to ``data``; its output is soft-maxed over dim 1.
        cl: callable applied to ``data``; its output is used as ``1 - cl(data)``.
        data: input tensor. The result is indexed with ``[:, :, 0]``, so the
            intermediate products must be 3-D.

    Returns:
        The normalised weights with the last dimension's index 0 selected.

    NOTE(review): the weight always uses ``1 - cl(data)`` and no labels are
    passed in -- confirm this is the intended weighting for every data point.
    """
    with torch.no_grad():
        focus_prob = F.softmax(fc(data), dim=1)
        class_prob = cl(data)

    unnormalised = focus_prob * (1 - class_prob)
    # Keep the reduced dimension so the division broadcasts across dim 1.
    totals = unnormalised.sum(dim=1, keepdim=True)
    posterior = unnormalised / totals
    return posterior[:, :, 0]

In [5]:
def maximization_step(p_z,focus,classification,data,labels,focus_optimizer,classification_optimizer,Criterion,
                      num_samples=10,seed=1234):
    """Run one sampled update of the focus and classification modules.

    For every data point, ``num_samples`` element indices are drawn (with
    replacement) from the rows of ``p_z``. The selected elements train the
    classification module against the repeated labels, and one-hot indicators
    of the drawn indices train the focus module on every element.

    All sizes are derived from ``data`` instead of being hard-coded, so any
    number of data points, elements per point, and features per element works.
    With the default ``num_samples`` and ``seed`` the behaviour matches the
    earlier fixed-size version on a ``(2, 3, 1)`` input.

    Args:
        p_z: ``(num_points, num_elements)`` non-negative sampling weights,
            one row per data point (passed to ``torch.multinomial``).
        focus: module applied to the flattened elements; a sigmoid is applied
            to its output here and column 0 is used.
        classification: module applied to the sampled elements; column 0 of
            its output is passed to ``Criterion`` directly, so it is presumably
            already a probability -- verify against the module used.
        data: ``(num_points, num_elements, num_features)`` tensor.
        labels: one target per data point.
        focus_optimizer: optimizer stepping ``focus``'s parameters.
        classification_optimizer: optimizer stepping ``classification``'s parameters.
        Criterion: loss callable ``Criterion(prediction, target)``.
        num_samples: number of indices drawn per data point (default 10).
        seed: value passed to ``torch.manual_seed`` before sampling. Pass
            ``None`` to leave the global RNG state untouched.

    Returns:
        The tuple ``(focus, classification, focus_optimizer,
        classification_optimizer)`` -- the same objects that were passed in,
        after one optimizer step each.

    Side effects:
        Prints both loss values, and (unless ``seed`` is ``None``) reseeds
        PyTorch's global random number generator.
    """
    if seed is not None:
        # Reseeding on every call makes the draw depend only on p_z.
        torch.manual_seed(seed)

    num_points, num_elements = data.shape[0], data.shape[1]
    num_rows = num_points * num_samples  # one row per (data point, sample)

    Z = torch.multinomial(p_z, num_samples, replacement=True).reshape(num_rows)
    data_repeat = data.repeat_interleave(num_samples, dim=0)
    row_index = torch.arange(num_rows, device=Z.device)

    # classification module data: the sampled element of each repeated point
    X_Z = data_repeat[row_index, Z]
    Y_Z = labels.repeat_interleave(num_samples)

    # focus module data: every element, with target 1 only for the sampled one
    Y_fc = torch.zeros((num_rows, num_elements), dtype=data.dtype, device=data.device)
    Y_fc[row_index, Z] = 1
    X_fc = data_repeat.reshape(num_rows * num_elements, -1)
    Y_fc = Y_fc.reshape(num_rows * num_elements)

    focus_optimizer.zero_grad()
    classification_optimizer.zero_grad()

    focus_outputs = torch.sigmoid(focus(X_fc))
    classification_outputs = classification(X_Z)

    loss_focus = Criterion(focus_outputs[:,0],Y_fc)
    loss_classification = Criterion(classification_outputs[:,0],Y_Z)

    print("Focus loss",loss_focus.item())
    print("Classification loss",loss_classification.item())

    # The two losses share no graph, so they are back-propagated independently.
    loss_focus.backward()
    loss_classification.backward()

    focus_optimizer.step()
    classification_optimizer.step()

    return focus,classification,focus_optimizer,classification_optimizer

In [6]:
# Two toy data points, each made of three one-feature elements; the points
# differ only in the sign of their last element. Labels are 0 and 1.
data = torch.tensor([[3., 3., -1.],
                     [3., 3., 1.]]).unsqueeze(-1)
labels = torch.tensor([0., 1.])
data.shape, labels.shape

(torch.Size([2, 3, 1]), torch.Size([2]))

In [19]:
# Build both modules and overwrite their randomly initialised parameters with
# fixed starting values, so every run starts from the same point.
focus = Focus(1)
focus.fc1.weight.data = torch.zeros(1, 1)
classification = Classification(1, 1)
classification.fc1.weight.data = torch.ones(1, 1)
classification.fc1.bias.data = torch.zeros(1)

# Binary cross-entropy on probabilities; plain SGD with a separate learning
# rate for each module.
Criterion = nn.BCELoss()
focus_optimizer = optim.SGD(focus.parameters(), lr=0.1)
classification_optimizer = optim.SGD(classification.parameters(), lr=0.01)

# Alternate the expectation and maximization steps for 100 rounds.
for i in range(100):
    p_z = expectation_step(focus, classification, data)
    (focus, classification,
     focus_optimizer, classification_optimizer) = maximization_step(
        p_z, focus, classification, data, labels,
        focus_optimizer, classification_optimizer, Criterion,
    )
    
    

Focus loss 0.6931472420692444
Classification loss 0.4235605299472809
Focus loss 0.6239665150642395
Classification loss 0.42342910170555115
Focus loss 0.5641141533851624
Classification loss 0.28625819087028503
Focus loss 0.513045072555542
Classification loss 0.2988192141056061
Focus loss 0.47748252749443054
Classification loss 0.29813408851623535
Focus loss 0.4509064257144928
Classification loss 0.29745179414749146
Focus loss 0.4306432604789734
Classification loss 0.2967722713947296
Focus loss 0.4149051606655121
Classification loss 0.2960955500602722
Focus loss 0.40247687697410583
Classification loss 0.29542168974876404
Focus loss 0.3925171196460724
Classification loss 0.29475054144859314
Focus loss 0.38443121314048767
Classification loss 0.29408207535743713
Focus loss 0.3777911961078644
Classification loss 0.29341641068458557
Focus loss 0.3722832500934601
Classification loss 0.2927533984184265
Focus loss 0.36767318844795227
Classification loss 0.2920931279659271
Focus loss 0.3637839555

In [20]:
# Show the focus module's parameters after the training loop.
for params in focus.parameters():
    print(params)

Parameter containing:
tensor([[-1.0120]], requires_grad=True)


In [21]:
# Show the classification module's parameters (weight, then bias) after the training loop.
for params in classification.parameters():
    print(params)

Parameter containing:
tensor([[1.2384]], requires_grad=True)
Parameter containing:
tensor([-0.0073], requires_grad=True)


In [22]:
# Focus scores normalised with a softmax over dim 1 (the three elements of each data point).
F.softmax(focus(data),dim=1)

tensor([[[0.0169],
         [0.0169],
         [0.9663]],

        [[0.1045],
         [0.1045],
         [0.7910]]], grad_fn=<SoftmaxBackward>)

In [23]:
# Raw focus scores, before any normalisation.
focus(data)

tensor([[[-3.0359],
         [-3.0359],
         [ 1.0120]],

        [[-3.0359],
         [-3.0359],
         [-1.0120]]], grad_fn=<UnsafeViewBackward>)