In [7]:
import torch
import torch.nn as nn

class TextHeatmapFinal(nn.Module):
    '''
    Combine best text classifer with most effective map fusion
    '''
    def __init__(self, heatmap_only=False, text_only=False):
        super(TextHeatmapFinal, self).__init__()
        # Modal setting
        self.heatmap_only = heatmap_only
        self.text_only = text_only

        # Text Transform
        self.text_liner = nn.Sequential(
            nn.Linear(768, 4096),
            nn.ReLU(inplace=True)
        )
        self.text_fusion_liner = nn.Sequential(
            nn.Linear(4096, 1000),
            nn.ReLU(inplace=True)
        )
        self.hidden_text = nn.Sequential(
            nn.Linear(1000, 300),
            nn.BatchNorm1d(300),
        )

        self.text_logits = nn.Sequential(
            nn.Linear(4096, 1),
            nn.Sigmoid()
        )

        # Map CNN
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, 10, padding=1),  # 8@216*216
            nn.ReLU(inplace=True),
            nn.MaxPool2d(4, stride=4),  # 8@54*54
            nn.Conv2d(8, 16, 7),
            nn.ReLU(inplace=True),  # 8@48*48
            nn.MaxPool2d(4, stride=4),  # 16@12*12
        )

        # Visual liners
        self.visual_liner = nn.Sequential(
            nn.Linear(16*12*12, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(inplace=True)
        )
        self.visual_fusion_liner = nn.Sequential(
            nn.Linear(1000*2, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(inplace=True)
        )
        self.hidden_visual = nn.Sequential(
            nn.Linear(1000, 300),
            nn.BatchNorm1d(300),
        )

        self.map_logits = nn.Sequential(
            nn.Linear(300, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        # Gates
        self.visual_gate = nn.Sequential(
            nn.Linear(1000*2, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000,300),
            nn.Sigmoid()
        )
        self.text_gate = nn.Sequential(
            nn.Linear(1000*2, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000,300),
            nn.Sigmoid()
        )
        
        # Output
        self.multimodal_logits = nn.Sequential(
            nn.Linear(300, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
        # weight init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def visual_forward(self, x):
        x = self.conv(x)
        x = x.view(x.size()[0], -1)
        x = self.visual_liner(x)
        return x

    def visual_fuse(self, x1, x2):
        x1 = x1.view(x1.size()[0], -1)
        x1 = self.visual_liner(x1)

        x2 = x2.view(x2.size()[0], -1)
        x2 = self.visual_liner(x2)
        # Fuse
        x = torch.abs(x1-x2)
        return x
    
    def gate_calc(self, v, t):
        x = torch.cat((v,t), 1)
        g_v = self.visual_gate(x)
        g_t = self.text_gate(x)
        return g_v, g_t
    
    def forward(self, v1, v2, t1, t2):
        # Text
        t1 = self.text_liner(t1) # 768->4096
        t2 = self.text_liner(t2) # 768->4096
        t = torch.abs(t1-t2)
        if self.text_only:
            logits = self.text_logits(t) # 4906->1, sigmoid
            return logits
        t = self.text_fusion_liner(t) # 4096 -> 1000
        h_t = self.hidden_text(t) # 1000->300
        
        # Map
        v1 = self.visual_forward(v1) # 224*224->CNN->1000
        v2 = self.visual_forward(v2) # 224*224->CNN->1000
        v = torch.cat((v1, v2), 1) 
        v = self.visual_fusion_liner(v) # 2000->1000
        h_v = self.hidden_visual(v)
        if self.heatmap_only:
            logits = self.map_logits(v)
            return logits
    
        # Fusion
        g_v, g_t = self.gate_calc(v,t)
        tanh = nn.Tanh()
        y = g_v*tanh(h_v) + g_t*tanh(h_t)
        logits = self.multimodal_logits(y)
        return logits



In [2]:
%cd /work/adapting-CLIP-VGPs/
from torch.utils.data import DataLoader
from utils.heatmap_data import VGPsHeatmapsDataset

GPU = 0

train_dataset = VGPsHeatmapsDataset(split='train')
train_dataset.image_idices = train_dataset.image_idices[:200]

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=4,
    pin_memory=True
    # sampler=train_sampler
)


/work/adapting-CLIP-VGPs


In [8]:
model = TextHeatmapFinal().to(GPU)
model.train()

TextHeatmapFinal(
  (text_liner): Sequential(
    (0): Linear(in_features=768, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
  )
  (text_fusion_liner): Sequential(
    (0): Linear(in_features=4096, out_features=1000, bias=True)
    (1): ReLU(inplace=True)
  )
  (hidden_text): Sequential(
    (0): Linear(in_features=1000, out_features=300, bias=True)
    (1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (text_logits): Sequential(
    (0): Linear(in_features=4096, out_features=1, bias=True)
    (1): Sigmoid()
  )
  (conv): Sequential(
    (0): Conv2d(1, 8, kernel_size=(10, 10), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(8, 16, kernel_size=(7, 7), stride=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  )
  (visual_liner): Sequential(
    (0): Linear(

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

# 損失関数の定義: pair wise loss
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([6.5]).to(GPU))
# criterion = nn.CrossEntropyLoss()
# pairwise loss, contrastive los

# 最適化手法の定義
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

In [5]:
gpu = GPU
phase = 'train'

batch_iter = iter(train_loader)
batch = next(batch_iter)
image_paths = batch['img_idx']
left_text_ft = batch['left_text_emb']
right_text_ft = batch['right_text_emb']
left_heatmaps = batch['left_heatmap']
right_heatmaps = batch['right_heatmap']
labels = batch['label']

left_heatmaps = left_heatmaps.unsqueeze(1).to(gpu)
right_heatmaps = right_heatmaps.unsqueeze(1).to(gpu)
left_text_ft = left_text_ft.squeeze(1).float().to(gpu)
right_text_ft = right_text_ft.squeeze(1).float().to(gpu)
label_tensor = labels.float().unsqueeze(1).to(gpu)

epoch_loss = 0.0
epoch_TP = 0
epoch_FP = 0
epoch_FN = 0
epoch_TN = 0
processed = 0

In [7]:
left_text_ft

tensor([[ 0.1593,  0.5981, -0.1082,  ..., -0.2119,  0.1201, -0.5293],
        [ 0.1593,  0.5981, -0.1082,  ..., -0.2119,  0.1201, -0.5293],
        [ 0.4602, -0.6152,  0.4688,  ...,  0.0058, -0.0092,  0.4028],
        ...,
        [ 0.1316,  0.7441,  0.2695,  ..., -0.4397, -0.0232,  0.3755],
        [ 0.1316,  0.7441,  0.2695,  ..., -0.4397, -0.0232,  0.3755],
        [ 0.0035,  0.2289,  0.4023,  ..., -0.0433,  0.1329, -0.6191]],
       device='cuda:0')

In [9]:
optimizer.zero_grad()
with torch.set_grad_enabled(phase=='train'):
    logits = model(left_heatmaps, right_heatmaps, left_text_ft, right_text_ft)
    loss = criterion(logits, label_tensor)
    preds = (logits>0.5).float()
    # loss = criterion(logits, torch.squeeze(label_tensor.type(torch.long)))
    # _, preds = logits.max(dim=1)
    

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [10]:
loss

tensor(1.3647, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [10]:
epoch_TP += ((preds.squeeze(1) == 1) & (label_tensor.squeeze(1) == 1)).float().sum().item()
epoch_FP += ((preds.squeeze(1) == 1) & (label_tensor.squeeze(1) == 0)).float().sum().item()
epoch_FN += ((preds.squeeze(1) == 0) & (label_tensor.squeeze(1) == 1)).float().sum().item()
epoch_TN += ((preds.squeeze(1) == 0) & (label_tensor.squeeze(1) == 0)).float().sum().item()  
epoch_prec = epoch_TP / (epoch_TP + epoch_FP) if (epoch_TP + epoch_FP) > 0 else 0
epoch_rec = epoch_TP / (epoch_TP + epoch_FN) if (epoch_TP + epoch_FN) > 0 else 0
epoch_f1 = 2 * epoch_prec * epoch_rec / (epoch_prec + epoch_rec) if (epoch_prec + epoch_rec) > 0 else 0

In [11]:
epoch_rec

0.4666666666666667

In [12]:
epoch_prec

0.13725490196078433

In [13]:
epoch_f1

0.21212121212121213