In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from load_data import load_data
import torch.nn.functional as F

# from torch.utils.data.sampler import SubsetRandomSampler

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



In [16]:

class CapsNet(nn.Module):
    def __init__ (self,input, caps_num, caps_dims):  # input=20 caps_num = 32 caps_dims =8
        super(CapsNet, self).__init__()
        self.caps_num = caps_num
        self.caps_dims = caps_dims
        self.input= input
        self.b_ij = torch.zeros(8,8)
        self.W = nn.Parameter(torch.randn(self.caps_num*13*13*self.caps_dims,8*8))
        self.ConvLayer = nn.Sequential(
            nn.Conv2d(in_channels=1,out_channels=256, kernel_size=9, stride=2),  #in20 ker9x256 strd2=>out[(20-9)/2]+1
            nn.ReLU(),
            nn.Conv2d(in_channels=256,out_channels=128, kernel_size=13, stride=3),  
            nn.ReLU(),
            nn.Conv2d(in_channels=128,out_channels=64, kernel_size=8, stride=2),  
            nn.ReLU()
            
            )
       
    def squash (self, s):  # squnsh vector sup a 1 to vec inf a 1 (kinda normalisation)
        norme = torch.sum(torch.square(s))
        s= (norme/(1+norme))*(s/torch.sqrt(norme))
        return s 
    
    def routing(self, x):

        # create random weights 
        # self.W = torch.cat([self.W], dim=0)   # why x batch size ?

        # Transform inputs by weight matrix.
        y= torch.reshape(x,[1, self.caps_num*13*13*self.caps_dims])
        u_hat = torch.matmul(y, self.W)    # 10*16 ta3 DigiCaps Layer
        u_hat= torch.reshape(u_hat, [8,8])    # psk 3ndna 8 capsule et 3ndna 8 vecteur de sortie donc les C c 8x8
        
        # Initialize routing logits to zero.
        
        c_ij = F.softmax(self.b_ij)
        V=[]
        # c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)
       
        # Iterative routing
        for i in range(8):   # psk 9olna 8 vecteurs de sortie
            s_j = (c_ij[:][i] * u_hat).sum(dim=1, keepdim=True)   # ymultipliyi ga3 les vecteurs b des c w ysommihom
            v_j = self.squash(s_j)
            v_j1 = torch.cat([v_j] * 8, dim=1)
            u_vj1 = torch.matmul(u_hat, v_j1).mean(dim=0, keepdim=True)

            # Update b_ij (routing)
            self.b_ij[:][i] = self.b_ij[:][i] + u_vj1
            
            V.append(v_j)

        return V
    
        
    def margin_loss(self, input, target, size_average=True):
        batch_size = 139

        # ||vc|| from the paper.
        v_mag = torch.sqrt((input**2).sum(dim=2, keepdim=True))

        # Calculate left and right max() terms from equation 4 in the paper.
        zero = torch.zeros(1)
        m_plus = 0.9
        m_minus = 0.1
        max_l = torch.max(m_plus - v_mag, zero).view(batch_size, -1)**2
        max_r = torch.max(v_mag - m_minus, zero).view(batch_size, -1)**2

        # This is equation 4 from the paper.
        loss_lambda = 0.5
        T_c = target
        L_c = T_c * max_l + loss_lambda * (1.0 - T_c) * max_r
        L_c = L_c.sum(dim=1)

        if size_average:
            L_c = L_c.mean()

        return L_c
    
    def forward(self, data, caps_num=8 ,caps_dims=8):
        batch_size = 1
        x= self.ConvLayer(data)
        x= torch.reshape(x, [batch_size, caps_num, caps_dims, 13 , 13])   # CapsLayers
        #  x= torch.reshape(x,[batch_size , caps_num*13*13, caps_dims])   # we reshape brk l tableau fih ga3 les vecteurs de dimensions 8 
        x= self.routing(x)
        return x


In [3]:
learning_rate = 0.01
batch_size = 1
test_batch_size = 128
epochs = 10

# Stop training if loss goes below this threshold.
early_stop_loss = 0.0001

# Create and Load data
trans = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor()])

data = load_data(csv_file='MMU.csv', transformer= trans)

traind, testd = random_split(data, [139, 20])

train_loader = DataLoader(traind, batch_size , shuffle= True)
test_loader = DataLoader(testd, batch_size , shuffle= True)

In [4]:
print(next(iter(train_loader))[0][0].shape)   # checking the image dimension brk hna
print(len(train_loader.dataset))  # checking the number of images

torch.Size([1, 224, 224])
139


In [17]:
CapsNets = CapsNet(train_loader, 8, 8)
print(CapsNets)


CapsNet(
  (ConvLayer): Sequential(
    (0): Conv2d(1, 256, kernel_size=(9, 9), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(256, 128, kernel_size=(13, 13), stride=(3, 3))
    (3): ReLU()
    (4): Conv2d(128, 64, kernel_size=(8, 8), stride=(2, 2))
    (5): ReLU()
  )
)


In [19]:
optimizer = optim.Adam( CapsNets.parameters() , lr=learning_rate)
log_interval = 1
CapsNets.train()

for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):

        data=data.to(device)
        target=target.to(device)
        
        optimizer.zero_grad()
        output = CapsNets(data)
        print(output)
        # loss = CapsNets.margin_loss(data, output, target)
        # loss.backward()
        # optimizer.step()

        # print(f'Epoch [{epoch}/{epochs}], Loss : {loss.item()}')
        
        # if batch_idx % log_interval == 0:
        #     print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        #         epoch,
        #         batch_idx * len(data),
        #         len(train_loader.dataset),
        #         100. * batch_idx / len(train_loader),
        #         loss.data[0]))

  c_ij = F.softmax(self.b_ij)


[tensor([[-0.5666],
        [-0.0697],
        [ 0.1370],
        [ 0.2868],
        [ 0.5369],
        [ 0.3559],
        [-0.1665],
        [ 0.2720]], grad_fn=<MulBackward0>), tensor([[-0.5666],
        [-0.0697],
        [ 0.1370],
        [ 0.2868],
        [ 0.5369],
        [ 0.3559],
        [-0.1665],
        [ 0.2720]], grad_fn=<MulBackward0>), tensor([[-0.5666],
        [-0.0697],
        [ 0.1370],
        [ 0.2868],
        [ 0.5369],
        [ 0.3559],
        [-0.1665],
        [ 0.2720]], grad_fn=<MulBackward0>), tensor([[-0.5666],
        [-0.0697],
        [ 0.1370],
        [ 0.2868],
        [ 0.5369],
        [ 0.3559],
        [-0.1665],
        [ 0.2720]], grad_fn=<MulBackward0>), tensor([[-0.5666],
        [-0.0697],
        [ 0.1370],
        [ 0.2868],
        [ 0.5369],
        [ 0.3559],
        [-0.1665],
        [ 0.2720]], grad_fn=<MulBackward0>), tensor([[-0.5666],
        [-0.0697],
        [ 0.1370],
        [ 0.2868],
        [ 0.5369],
        [ 0.35