In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import os

def set_seed(seed):
    """Seed every RNG this notebook relies on and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU plus all CUDA
    devices), then switches cuDNN to deterministic kernels with autotuning off.

    Note: ``PYTHONHASHSEED`` is only read at interpreter start-up, so exporting it
    here does not change string hashing in the running kernel; it only affects
    processes launched afterwards that inherit this environment.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)
    
class Mutation_MIL_MT(nn.Module):
    """Attention-based multiple-instance-learning (MIL) model with multi-task heads.

    ``forward`` currently performs only attention pooling over the tiles of a bag.
    ``embedding_layer``, ``hidden_layers`` and ``dropout`` are constructed (so their
    parameters exist and consume RNG state at init) but are not used by ``forward`` yet.

    Args:
        in_features: Feature size of each tile (last dimension of the input).
        act_func: Activation used inside ``embedding_layer``; one of
            ``'tanh'``, ``'relu'``, ``'leakyrelu'``.
        drop_out: Dropout probability for ``self.dropout``.
        n_outcomes: Number of single-logit outcome heads in ``hidden_layers``.
        dim_out: Width of the embedding produced by ``embedding_layer`` and
            consumed by each outcome head.

    Raises:
        ValueError: If ``act_func`` is not one of the supported names.
    """

    def __init__(self, in_features = 2048, act_func = 'tanh', drop_out = 0, n_outcomes = 7, dim_out = 128):
        super().__init__()
        activations = {
            'leakyrelu': nn.LeakyReLU,
            'tanh': nn.Tanh,
            'relu': nn.ReLU,
        }
        # Fail fast: previously an unknown name left self.act_func unset and the
        # error only surfaced as an AttributeError while building embedding_layer.
        if act_func not in activations:
            raise ValueError(
                f"Unsupported act_func {act_func!r}; expected one of {sorted(activations)}"
            )

        self.in_features = in_features
        self.L = in_features      # input width of the attention network
        self.D = 128              # hidden width of the attention network
        self.K = 1                # one attention score per tile
        self.n_outs = n_outcomes  # number of outcome heads
        self.d_out = dim_out      # width of the embedding fed to each outcome head
        self.drop_out = drop_out
        # A single stateless activation instance is shared by every slot below.
        self.act_func = activations[act_func]()

        # Scores each tile. NOTE(review): the trailing Tanh bounds raw scores to
        # [-1, 1], so softmax weights can differ by at most a factor of e**2 —
        # confirm this is intended.
        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K),
            nn.Tanh()
        )

        self.embedding_layer = nn.Sequential(
            nn.Linear(self.in_features, 1024),
            self.act_func,
            nn.Linear(1024, 512),
            self.act_func,
            nn.Linear(512, 256),
            self.act_func,
            # Must match the heads' input width (was hard-coded to 128).
            nn.Linear(256, self.d_out),
        )

        # One single-logit head per outcome.
        self.hidden_layers = nn.ModuleList([nn.Linear(self.d_out, 1) for _ in range(self.n_outs)])

        self.dropout = nn.Dropout(p=drop_out)

    def forward(self, x):
        r'''Attention-pool the tiles of a bag.

        Args:
            x: Tile features with the tile axis second-to-last, e.g.
                ``[1, N_TILE, N_FEATURE]`` (or unbatched ``[N_TILE, N_FEATURE]``),
                where ``N_FEATURE == self.in_features``.

        Returns:
            Tuple ``(pooled, M, A)``:
                pooled: attention-weighted sum over tiles, e.g. ``[1, N_FEATURE]``.
                M: per-tile weighted features ``x * A``, same size as ``x``.
                A: attention weights, e.g. ``[1, N_TILE, 1]``, summing to 1 over tiles.
        '''
        A = self.attention(x)          # [..., N_TILE, K]
        # The tile axis is dim=-2. dim=1 only worked for batched 3-D input; for
        # [N, F] it softmaxed over the size-1 score axis and produced all ones.
        A = F.softmax(A, dim=-2)
        M = x * A                      # broadcast each tile's weight across features
        pooled = M.sum(dim=-2)         # [..., N_FEATURE]

        # TODO: embedding_layer, dropout and hidden_layers are not applied yet.
        return pooled, M, A


In [6]:
# Seed before construction so the layer initialisation is reproducible.
SEED = 0
set_seed(SEED)

# Defaults spelled out so the model configuration is visible in the notebook.
mod = Mutation_MIL_MT(in_features=2048, act_func='tanh', drop_out=0, n_outcomes=7, dim_out=128)

In [7]:
# Smoke test: one bag of 300 random tiles, shape [batch=1, n_tiles=300, n_features].
x  = torch.rand(1, 300, mod.in_features, dtype = torch.float32)
print(x.shape)
x2, M, a = mod(x)
print(x2.shape)
print(M.shape)
print(a.shape)

# Attention must be a distribution over tiles, and the pooled vector its weighted sum.
assert x2.shape == (1, mod.in_features)
assert M.shape == x.shape
assert a.shape == (1, 300, 1)
assert torch.allclose(a.sum(dim=1), torch.ones(1, 1))
assert torch.allclose(x2, M.sum(dim=1))

torch.Size([1, 300, 2048])
torch.Size([1, 2048])
torch.Size([1, 300, 2048])
torch.Size([1, 300, 1])
