In [1]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import matplotlib

from tqdm import tqdm
%matplotlib inline
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision

import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from tqdm import tqdm as tqdm
# Prefer the GPU when one is visible.
# NOTE(review): no later cell appears to move models or tensors to `device`,
# so the computation in this notebook runs on the default (CPU) device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


In [2]:
class Focus(nn.Module):
    """Scores every patch with one bias-free linear projection.

    Maps ``(..., input_dims) -> (..., 1)``: all leading dimensions of the
    input are kept and the feature dimension is reduced to a single score.
    The cells in this notebook softmax these scores over the patch
    dimension (dim 1) to obtain focus probabilities.
    """

    def __init__(self, input_dims):
        super().__init__()
        self.input_dims = input_dims
        # No bias: it would be the same for every patch and therefore cancel
        # under the softmax over patches.
        self.fc1 = nn.Linear(input_dims, 1, bias=False)

    def forward(self, x):
        return self.fc1(x)
        

In [3]:
class Classification(nn.Module):
    """Linear classifier applied independently to every patch.

    Maps ``(..., input_dims) -> (..., output_dims)``. With more than one
    output it returns the raw logits (what ``nn.CrossEntropyLoss`` expects);
    otherwise it returns ``sigmoid(logit)``.
    """

    def __init__(self, input_dims, output_dims):
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.fc1 = nn.Linear(input_dims, output_dims)

    def forward(self, x):
        logits = self.fc1(x)
        return logits if self.output_dims > 1 else torch.sigmoid(logits)

In [4]:
def calculate_loss_focus(gamma,focus_output):
    """Responsibility-weighted negative log-likelihood of the focus model.

    Computes ``-mean_b sum_z gamma[b, z] * log(focus_output[b, z])``, the
    focus part of the M-step objective.

    Args:
        gamma: patch responsibilities, shape ``(batch, patches)``.
        focus_output: focus probabilities, same shape as ``gamma``.

    Returns:
        Scalar tensor.
    """
    # Clamp before the log. If a probability underflows to exactly 0,
    # torch.log gives -inf, and a zero responsibility then yields
    # 0 * -inf = nan, which poisons the loss and every gradient. Clamping to
    # the smallest positive normal float keeps such terms finite (with zero
    # gradient) and leaves every probability >= that value unchanged.
    tiny = torch.finfo(focus_output.dtype).tiny
    log_outputs = torch.log(focus_output.clamp_min(tiny))

    # Sum over patches, then average over the batch.
    loss_ = torch.sum(gamma * log_outputs, dim=1)
    return -torch.mean(loss_, dim=0)
    

In [5]:
def calculate_loss_classification(gamma,classification_output,label,criterion,n_patches):
    """Responsibility-weighted classification loss for the M-step.

    Every patch of a sample is scored against that sample's label. The
    per-patch losses are weighted by ``gamma``, summed over patches, and
    averaged over the batch.

    Args:
        gamma: weights broadcastable to ``(batch, n_patches)``.
        classification_output: per-patch class scores, shape
            ``(batch, n_patches, classes)``.
        label: class indices, shape ``(batch,)``.
        criterion: loss returning one value per row, e.g.
            ``nn.CrossEntropyLoss(reduction="none")``.
        n_patches: number of patches per sample.

    Returns:
        Scalar tensor.
    """
    batch_size, num_classes = label.size(0), classification_output.size(2)
    flat_scores = classification_output.reshape((batch_size * n_patches, num_classes))
    # One copy of each label per patch, aligned with the flattened rows.
    flat_labels = label.repeat_interleave(n_patches)
    per_patch = criterion(flat_scores, flat_labels).reshape((batch_size, n_patches))
    return (gamma * per_patch).sum(dim=1).mean(dim=0)
    

In [6]:
def expectation_step(fc,cl,data,labels):
    """E-step: posterior over which patch explains each sample's label.

    For sample ``b`` and patch ``z`` this returns

        p(z | x_b, y_b) = f(z | x_b) * g(y_b | x_{b,z}) / sum_z' (same)

    where ``f`` is the softmax of ``fc(data)`` over patches (channel 0) and
    ``g`` is the softmax of ``cl(data)`` over classes.

    Args:
        fc: focus model; ``fc(data)`` is softmaxed over dim 1 and channel 0
            of the last dim is used.
        cl: classifier; ``cl(data)`` is softmaxed over dim 2 (classes).
        data: input batch; in this notebook ``(batch, patches, features)``.
        labels: integer class labels, shape ``(batch,)``.

    Returns:
        Tensor of shape ``(batch, patches)`` whose rows sum to 1. Computed
        under ``torch.no_grad()``, so no gradient is tracked.
    """
    batch = data.size(0)
    with torch.no_grad():
        # Work in log space. Multiplying two small probabilities can underflow
        # to 0 for every patch, and normalizing that product gives 0/0 = nan.
        log_f = F.log_softmax(fc(data), dim=1)[:, :, 0]
        log_g = F.log_softmax(cl(data), dim=2)
        # Log-probability of each sample's true label, for every patch.
        log_g = log_g[torch.arange(batch, device=labels.device), :, labels]
        # softmax over patches == exp(log f + log g), renormalized per sample.
        return F.softmax(log_f + log_g, dim=1)

In [7]:
def maximization_step(p_z,focus,classification,data,labels,focus_optimizer,classification_optimizer,Criterion,verbose=True):
    """M-step: one gradient update of both models with fixed responsibilities.

    Args:
        p_z: E-step responsibilities, shape ``(batch, patches)``.
        focus: focus model; its output is softmaxed over dim 1 (patches) and
            channel 0 of the last dim is used.
        classification: classifier producing per-patch class scores.
        data: input batch; ``data.size(1)`` is the number of patches.
        labels: class labels, shape ``(batch,)``.
        focus_optimizer: optimizer over ``focus`` parameters.
        classification_optimizer: optimizer over ``classification`` parameters.
        Criterion: per-element loss (``reduction="none"``) for the classifier.
        verbose: if True (default, the original behavior), print both losses.
            Pass False to avoid two printed lines per EM iteration.

    Returns:
        ``(focus, classification, focus_optimizer, classification_optimizer)``:
        the same objects that were passed in, with parameters updated in place
        by the optimizer steps.
    """
    patches = data.size(1)
    focus_optimizer.zero_grad()
    classification_optimizer.zero_grad()

    focus_outputs = F.softmax(focus(data), dim=1)[:, :, 0]
    # Raw per-patch class scores. With more than one output the Classification
    # model in this notebook returns logits (no softmax); Criterion there is
    # CrossEntropyLoss, which applies log-softmax itself.
    classification_outputs = classification(data)

    loss_focus = calculate_loss_focus(p_z, focus_outputs)
    loss_classification = calculate_loss_classification(
        p_z, classification_outputs, labels, Criterion, patches)

    if verbose:
        print("Focus loss", loss_focus.item())
        print("Classification loss", loss_classification.item())

    # The two losses depend on disjoint parameter sets, so separate backward
    # passes are fine.
    loss_focus.backward()
    loss_classification.backward()
    focus_optimizer.step()
    classification_optimizer.step()

    return focus, classification, focus_optimizer, classification_optimizer

In [8]:
# Toy dataset: 2 samples x 3 patches x 1 feature. In each sample one patch
# differs from the other two (the "foreground"); its sign matches the label
# (-1 -> class 0, +1 -> class 1).
data = torch.tensor([[[3.],[3.],[-1.]],[[3.],[+1.],[3.]]])
labels = torch.tensor([0,1])
# Index of the foreground patch in each sample. The evaluation cells compare
# the focus model's argmax against this, so it must be defined here for the
# notebook to survive Restart & Run All.
fore_idx = np.array([2, 1])
data.shape,labels.shape

(torch.Size([2, 3, 1]), torch.Size([2]))

In [16]:
# EM training on the toy data with a fixed, deterministic initialization:
# zero focus weight (uniform attention over patches) and a small symmetric
# classifier.
N_EM_ITERATIONS = 200
LEARNING_RATE = 0.5

focus = Focus(1)
classification = Classification(1,2)
# Overwrite the random init in place without autograd tracking; assigning to
# `.data` directly is discouraged in current PyTorch.
with torch.no_grad():
    focus.fc1.weight.copy_(torch.tensor([[0.]]))
    classification.fc1.weight.copy_(torch.tensor([[0.1],[-0.1]]))
    classification.fc1.bias.copy_(torch.tensor([0.,0.]))

Criterion = nn.CrossEntropyLoss(reduction="none")
focus_optimizer = optim.SGD(focus.parameters(), lr=LEARNING_RATE)
classification_optimizer = optim.SGD(classification.parameters(), lr=LEARNING_RATE)

for _ in range(N_EM_ITERATIONS):
    p_z = expectation_step(focus, classification, data, labels)
    focus, classification, focus_optimizer, classification_optimizer = maximization_step(
        p_z, focus, classification, data, labels,
        focus_optimizer, classification_optimizer, Criterion)
    

Focus loss 1.0986123085021973
Classification loss 0.7376129627227783
Focus loss 1.103293538093567
Classification loss 0.6849931478500366
Focus loss 1.0956631898880005
Classification loss 0.6763752698898315
Focus loss 1.082407832145691
Classification loss 0.6508828401565552
Focus loss 1.0344674587249756
Classification loss 0.6073635816574097
Focus loss 0.92075514793396
Classification loss 0.5168488025665283
Focus loss 0.7539455890655518
Classification loss 0.3891769051551819
Focus loss 0.6281040906906128
Classification loss 0.28865480422973633
Focus loss 0.5569226741790771
Classification loss 0.2299833595752716
Focus loss 0.5112577080726624
Classification loss 0.19291278719902039
Focus loss 0.47777020931243896
Classification loss 0.1666490137577057
Focus loss 0.451302707195282
Classification loss 0.14668597280979156
Focus loss 0.4294686019420624
Classification loss 0.13086259365081787
Focus loss 0.41096049547195435
Classification loss 0.11797450482845306
Focus loss 0.39496952295303345
C

Focus loss 0.126972496509552
Classification loss 0.00586796086281538
Focus loss 0.12668296694755554
Classification loss 0.005834284704178572
Focus loss 0.12639571726322174
Classification loss 0.005801002029329538
Focus loss 0.12611067295074463
Classification loss 0.0057681649923324585
Focus loss 0.1258278340101242
Classification loss 0.0057356031611561775
Focus loss 0.1255471259355545
Classification loss 0.005703491624444723
Focus loss 0.12526853382587433
Classification loss 0.005671655759215355
Focus loss 0.12499206513166428
Classification loss 0.0056402115151286125
Focus loss 0.12471773475408554
Classification loss 0.005609102547168732
Focus loss 0.12444540858268738
Classification loss 0.005578329786658287
Focus loss 0.12417513877153397
Classification loss 0.00554788950830698
Focus loss 0.12390682101249695
Classification loss 0.005517784506082535
Focus loss 0.1236405298113823
Classification loss 0.005488013848662376
Focus loss 0.12337622046470642
Classification loss 0.005458523984998

In [17]:
# Show the learned focus weight(s).
for focus_param in focus.parameters():
    print(focus_param)
    

Parameter containing:
tensor([[-1.7961]], requires_grad=True)


In [18]:
# Show the learned classifier weight and bias.
for clf_param in classification.parameters():
    print(clf_param)

Parameter containing:
tensor([[-2.6065],
        [ 2.6065]], requires_grad=True)
Parameter containing:
tensor([ 0.0283, -0.0283], requires_grad=True)


In [19]:
# method 1: hard attention - classify only the patch the focus model ranks highest.
with torch.no_grad():  # evaluation only; don't build autograd graphs
    batch = data.size(0)
    # argmax of the softmax equals argmax of the raw scores; kept for clarity.
    indexes = torch.argmax(F.softmax(focus(data), dim=1), dim=1)[:, 0].numpy()
    print("Focus True", (np.sum(indexes == fore_idx, axis=0).item() / len(fore_idx)) * 100)
    outputs = F.softmax(classification(data[np.arange(batch), indexes, :]), dim=1)
    prediction = torch.argmax(outputs, dim=1)
    accuracy = (prediction == labels).float().mean() * 100
    print("Accuracy", accuracy.item())

Focus True 100.0
Accuracy 100.0


In [20]:
# method 2: soft attention - weight every patch's class probabilities by the
# focus probabilities and predict the argmax of that mixture.
with torch.no_grad():  # evaluation only; don't build autograd graphs
    focus_output = F.softmax(focus(data), dim=1)
    # Reuse the focus probabilities instead of running the model a second time.
    indexes = torch.argmax(focus_output, dim=1)[:, 0].numpy()
    classification_output = F.softmax(classification(data), dim=2)
    print("Focus True", (np.sum(indexes == fore_idx, axis=0).item() / len(fore_idx)) * 100)
    # focus_output has a trailing dim of 1 (Focus(1)), so it broadcasts over
    # the class dim; summing over dim 1 mixes the patches.
    prediction = torch.argmax(torch.sum(focus_output * classification_output, dim=1), dim=1)
    accuracy = (prediction == labels).float().mean() * 100
    print("Accuracy", accuracy.item())

Focus True 100.0
Accuracy 100.0
