In [None]:
!pip install transformers

In [3]:
import transformers
import torch.nn as nn

In [4]:
class MCAN(nn.Module):
    """Layer inventory for the MCAN model, used to inspect its parameter count.

    Every sub-layer is instantiated in ``__init__`` with hard-coded sizes, but
    none of them is wired together: ``forward`` is a placeholder that hands its
    input straight back.
    """

    def __init__(self):
        """Create all sub-layers (takes no arguments; sizes are fixed)."""
        super().__init__()

        # Frequency feature extractor: conv1 .. conv11, listed in order as
        # (in_channels, out_channels, kernel_size).
        # NOTE(review): the channel counts do not chain from one layer to the
        # next (e.g. conv3 emits 128 channels while conv4 expects 64), so these
        # are presumably meant for parallel branches -- the wiring is not
        # defined in this class; confirm against the reference architecture.
        conv_shapes = (
            (64, 32, 3),
            (32, 64, 3),
            (64, 128, 3),
            (64, 64, 1),
            (64, 48, 1),
            (64, 64, 1),
            (64, 32, 1),
            (48, 64, 3),
            (64, 96, 3),
            (96, 96, 3),
            (64, 64, 1),
        )
        for position, (n_in, n_out, width) in enumerate(conv_shapes, start=1):
            # nn.Module.__setattr__ registers each layer under "conv<position>".
            setattr(self, f"conv{position}", nn.Conv1d(n_in, n_out, width))

        # Resize different modalities to a common 256-dim width.
        # NOTE(review): t/s/f presumably stand for text / spatial / frequency
        # features -- confirm.
        shared_dim = 256
        self.t_fc = nn.Linear(in_features=768, out_features=shared_dim)
        self.s_fc = nn.Linear(in_features=1024, out_features=shared_dim)
        self.f_fc = nn.Linear(in_features=122, out_features=shared_dim)

        # All four co-attention layers share one configuration.
        attention_cfg = dict(d_model=shared_dim, nhead=4, dim_feedforward=512)
        co_embed_in = 512

        # Co-Attention Block 1, followed by co embed 1.
        self.ca1 = nn.TransformerEncoderLayer(**attention_cfg)
        self.ca2 = nn.TransformerEncoderLayer(**attention_cfg)
        self.co_embed1 = nn.Linear(in_features=co_embed_in, out_features=shared_dim)

        # Co-Attention Block 2, followed by co embed 2.
        self.ca3 = nn.TransformerEncoderLayer(**attention_cfg)
        self.ca4 = nn.TransformerEncoderLayer(**attention_cfg)
        self.co_embed2 = nn.Linear(in_features=co_embed_in, out_features=shared_dim)

        # Fake news detector: 256 -> 35 -> 2.
        self.p_fc = nn.Linear(in_features=shared_dim, out_features=35)
        self.p2_fc = nn.Linear(in_features=35, out_features=2)

    def forward(self, x):
        """Placeholder forward pass: returns ``x`` unchanged."""
        return x

In [5]:
# Instantiate the MCAN skeleton (no arguments; all layer sizes are hard-coded).
mcan = MCAN()

In [6]:
# Total number of scalar entries across every parameter tensor of `mcan`.
sum(param.numel() for param in mcan.parameters())

2981211

In [10]:
class HMCAN(nn.Module):
    """Layer inventory for the HMCAN model, used to inspect its parameter count.

    Four single-head transformer encoder layers and a 2-way linear detector are
    instantiated, but they are not wired together: ``forward`` is a placeholder
    that returns its input unchanged.
    """

    def __init__(self):
        """Create all sub-layers (takes no arguments; sizes are fixed)."""
        super().__init__()

        hidden = 768

        def contextual_layer():
            # Single-head encoder layer; the feed-forward width is left at the
            # library default.
            return nn.TransformerEncoderLayer(d_model=hidden, nhead=1)

        # NOTE: "siglemodal" / "multlemodal" are misspellings, kept as-is
        # because renaming would change the attribute names and the
        # state_dict keys derived from them.

        # Contextual transformer 1.
        # Concatenating its two branches gives C_{TI} = 1536 dim.
        self.ct1_siglemodal = contextual_layer()
        self.ct1_multlemodal = contextual_layer()

        # Contextual transformer 2.
        # Concatenating its two branches gives C_{IT} = 1536 dim.
        self.ct2_siglemodal = contextual_layer()
        self.ct2_multlemodal = contextual_layer()

        # Therefore C_i = \alpha * C_IT + \beta * C_TI == 1536 dim.
        # Second option: assume 768 dim because of pooling and post dim; the
        # detector below is sized for that 768-dim option.

        # Fake news detector.
        self.fnd = nn.Linear(in_features=hidden, out_features=2)

    def forward(self, x):
        """Placeholder forward pass: returns ``x`` unchanged."""
        return x

In [11]:
# Instantiate the HMCAN skeleton (no arguments; all layer sizes are hard-coded).
hmcan = HMCAN()

In [12]:
# Total number of scalar entries across every parameter tensor of `hmcan`.
sum(param.numel() for param in hmcan.parameters())

22057474

In [7]:
class SpotFake(nn.Module):
    """Layer inventory for the SpotFake model, used to inspect its parameter count.

    Only the linear layers are instantiated; ``forward`` is a placeholder that
    returns its input unchanged.
    """

    def __init__(self):
        """Create all sub-layers (takes no arguments; sizes are fixed)."""
        super().__init__()

        branch_out = 32  # width each modality is reduced to before fusion
        fused_hidden = 35

        # Image feature resize, in two steps: 4096 -> 2742 -> 32.
        self.im1 = nn.Linear(in_features=4096, out_features=2742)
        self.im2 = nn.Linear(in_features=2742, out_features=branch_out)

        # Text feature resize, in two steps: 768 -> 768 -> 32.
        self.t1 = nn.Linear(in_features=768, out_features=768)
        self.t2 = nn.Linear(in_features=768, out_features=branch_out)

        # Multimodal layer after concat: 32 image + 32 text features = 64 inputs.
        self.mm1 = nn.Linear(in_features=2 * branch_out, out_features=fused_hidden)

        # Fake news detection layer (single output unit).
        self.fnd = nn.Linear(in_features=fused_hidden, out_features=1)

    def forward(self, x):
        """Placeholder forward pass: returns ``x`` unchanged."""
        return x

In [8]:
# Instantiate the SpotFake skeleton (no arguments; all layer sizes are hard-coded).
spotfake = SpotFake()

In [9]:
# Total number of scalar entries across every parameter tensor of `spotfake`.
sum(param.numel() for param in spotfake.parameters())

11939261