In [1]:
import torch
import torch.nn as nn

class TextHeatmapGatedClassifier(nn.Module):
    """Binary classifier over a pair of heatmaps and a pair of text embeddings.

    Each heatmap goes through a shared CNN + MLP encoder and each text
    embedding through a shared MLP encoder. The two visual features are
    concatenated and fused, likewise the two text features. In the default
    (gated) mode, two sigmoid gates computed from the fused visual and text
    features weight the 300-d visual and text hidden vectors, and their sum is
    passed to the output head.

    The output head ends in ``nn.Sigmoid``, so ``forward`` returns values in
    (0, 1) (probabilities), not raw logits, despite the attribute name
    ``logits``. A loss that applies its own sigmoid (e.g.
    ``nn.BCEWithLogitsLoss``) would therefore apply sigmoid twice.

    NOTE(review): ``use_dropout``, ``dropout``, ``softmax``, ``relu``,
    ``batchnorm_visual`` and ``batchnorm_text`` are stored on the module but
    are not referenced by any method of this class.
    """
    def __init__(self, heatmap_only=False, text_only=False, use_dropout=False):
        """Build all sub-networks and initialise Conv2d / Linear weights.

        Args:
            heatmap_only: if True, ``forward`` classifies from the visual
                branch alone (checked after ``text_only``).
            text_only: if True, ``forward`` classifies from the text branch
                alone; takes precedence over ``heatmap_only``.
            use_dropout: stored on the instance; currently not used by
                ``forward``.
        """
        super(TextHeatmapGatedClassifier, self).__init__()
        # Flags and generic layers (several are currently unused, see class docstring)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.2)
        self.softmax = nn.Softmax(dim=1)
        self.relu = nn.ReLU(inplace=True)
        self.heatmap_only = heatmap_only
        self.text_only = text_only
        self.batchnorm_visual = nn.BatchNorm1d(1000) # bn in visual fusion
        self.batchnorm_text = nn.BatchNorm1d(1000) # bn in language fusion

        # Map CNN: single-channel input -> 128 feature maps.
        # The shape comments below assume a 223x223 input; any square input of
        # 223..230 px yields the 24x24 map that visual_liner requires
        # (TODO confirm the actual heatmap size).
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 10, padding=1),  # 64@216*216
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),  # 64@108*108
            nn.Conv2d(64, 128, 7),
            nn.ReLU(inplace=True),  # 128@102*102
            nn.MaxPool2d(2, stride=2),  # 128@51*51
            
            nn.Conv2d(128, 128, 4),
            nn.ReLU(inplace=True),  # 128@48*48
            nn.MaxPool2d(2, stride=2),  # 128@24*24
        )

        # Visual linear layers: flattened CNN output -> 1000-d feature per heatmap
        self.visual_liner = nn.Sequential(
            nn.Linear(128*24*24, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(inplace=True)
        )
        # Fuses the two concatenated 1000-d visual features into one 1000-d vector
        self.visual_fuse = nn.Sequential(
            nn.Linear(1000*2, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(inplace=True)
        )
        # 300-d visual hidden vector (tanh-bounded) that the visual gate multiplies
        self.hidden_visual = nn.Sequential(
            nn.Linear(1000, 300),
            nn.BatchNorm1d(300),
            nn.Tanh()
        )

        # Text linear layers: 768-d embedding -> 1000-d feature per text
        self.text_liner = nn.Sequential(
            nn.Linear(768, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(inplace=True)
        )
        # Fuses the two concatenated 1000-d text features into one 1000-d vector
        self.text_fuse = nn.Sequential(
            nn.Linear(1000*2, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(inplace=True)
        )
        # 300-d text hidden vector (tanh-bounded) that the text gate multiplies
        self.hidden_text = nn.Sequential(
            nn.Linear(1000, 300),
            nn.BatchNorm1d(300),
            nn.Tanh()
        )
        
        # Gates: both take concat(fused visual, fused text) = 2000-d and emit
        # 300 sigmoid values used as element-wise weights.
        self.visual_gate = nn.Sequential(
            nn.Linear(1000*2, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000,300),
            nn.Sigmoid()
        )
        self.text_gate = nn.Sequential(
            nn.Linear(1000*2, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000,300),
            nn.Sigmoid()
        )

        # Output head: 300-d -> 1 value squashed by Sigmoid (a probability,
        # not a raw logit, despite the attribute name).
        self.logits = nn.Sequential(
            nn.Linear(300, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        # Weight init: Kaiming-uniform for conv weights, Xavier-uniform for
        # linear weights; biases and BatchNorm parameters keep their defaults.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def visual_forward(self, x):
        """Encode one batch of heatmaps into 1000-d visual features.

        Runs the CNN, flattens everything except the batch dimension, and
        applies ``visual_liner``.
        """
        x = self.conv(x)
        # Flatten to (batch, -1); must equal 128*24*24 features for visual_liner.
        x = x.view(x.size()[0], -1)
        x = self.visual_liner(x)
        return x

    def visual_fusion(self, v1, v2):
        """Concatenate two visual features along dim 1 and fuse them to 1000-d."""
        v = torch.cat((v1, v2), 1)
        v = self.visual_fuse(v)
        return v
    
    def text_forward(self, t):
        """Encode one batch of 768-d text embeddings into 1000-d text features."""
        t = self.text_liner(t)
        return t
    
    def text_fusion(self, t1, t2):
        """Concatenate two text features along dim 1 and fuse them to 1000-d."""
        t = torch.cat((t1, t2), 1)
        t = self.text_fuse(t)
        return t

    def gate_calc(self, v, t):
        """Compute the visual and text gates from the fused features.

        Both gates see the same input, concat(v, t) along dim 1.

        Returns:
            Tuple ``(g_v, g_t)`` of 300-d sigmoid gate activations.
        """
        x = torch.cat((v,t), 1)
        g_v = self.visual_gate(x)
        g_t = self.text_gate(x)
        return g_v, g_t

    def forward(self, x1, x2, t1, t2):
        """Classify a (heatmap pair, text pair) batch.

        Args:
            x1, x2: the two heatmap batches, fed to a Conv2d with one input
                channel.
            t1, t2: the two text-embedding batches, fed to a Linear with 768
                input features.

        Returns:
            Output of the sigmoid head (one value per sample, in (0, 1)).
            The text branch is always evaluated; with ``text_only`` the visual
            inputs are ignored, and with ``heatmap_only`` the text result is
            computed but not used for the output.
        """
        # merge text feature pairs
        t1 = self.text_forward(t1)
        t2 = self.text_forward(t2)
        t = self.text_fusion(t1, t2)
        h_t = self.hidden_text(t)
        # Text only (checked first, so it wins if both flags are set)
        if self.text_only:
            logits = self.logits(h_t)
            return logits
        
        # encode image pairs
        v1 = self.visual_forward(x1)
        v2 = self.visual_forward(x2)
        # fuse visual feature pairs
        v = self.visual_fusion(v1, v2)
        h_v = self.hidden_visual(v)
        # Map only
        if self.heatmap_only:
            logits = self.logits(h_v)
            return logits
        
        # fuse visual and text with gates: element-wise gated sum of the two
        # 300-d hidden vectors (the gates are independent sigmoids, so they
        # need not sum to 1)
        visual_gate, text_gate = self.gate_calc(v,t)
        y = visual_gate*h_v + text_gate*h_t

        logits = self.logits(y)
        return logits

In [2]:
%cd /work/adapting-CLIP-VGPs/
from torch.utils.data import DataLoader
from utils.heatmap_data import VGPsHeatmapsDataset

GPU = 4

train_dataset = VGPsHeatmapsDataset(split='train')
train_dataset.image_idices = train_dataset.image_idices[:100000]

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=4,
    pin_memory=True
    # sampler=train_sampler
)


/work/adapting-CLIP-VGPs


In [3]:
# Build the classifier with its default flags (gated text + heatmap mode),
# then move its parameters to the selected device.
model = TextHeatmapGatedClassifier()
model = model.to(GPU)
# train() switches to training mode and returns the module, so this cell
# also displays the model structure.
model.train()

TextHeatmapGatedClassifier(
  (dropout): Dropout(p=0.2, inplace=False)
  (softmax): Softmax(dim=1)
  (relu): ReLU(inplace=True)
  (batchnorm_visual): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batchnorm_text): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv): Sequential(
    (0): Conv2d(1, 64, kernel_size=(10, 10), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(7, 7), stride=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 128, kernel_size=(4, 4), stride=(1, 1))
    (7): ReLU(inplace=True)
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (visual_liner): Sequential(
    (0): Linear(in_features=73728, out_features=4096, bias=True)
    (1): Bat

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

# Loss function: binary cross-entropy on probabilities.
# The model's output head already ends in nn.Sigmoid, so it emits values in
# (0, 1). nn.BCELoss consumes those directly; nn.BCEWithLogitsLoss would apply
# a second sigmoid on top of them, squeezing every prediction into roughly
# (0.5, 0.73) and distorting both the loss value and its gradients.
criterion = nn.BCELoss()
# Alternatives considered: nn.CrossEntropyLoss(), nn.CosineEmbeddingLoss(),
# a pairwise / contrastive loss.

# Optimizer: SGD with momentum over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=0.0005, momentum=0.9)

In [5]:
# Device index and phase flag reused by the following cells.
gpu = GPU
phase = 'train'

# Pull a single batch from the training loader.
batch_iter = iter(train_loader)
batch = next(batch_iter)
image_paths = batch['img_idx']
labels = batch['label']

# Heatmaps: insert a new axis at position 1 (the model's first Conv2d takes
# one input channel), then move to the device.
left_heatmaps = batch['left_heatmap'].unsqueeze(1).to(gpu)
right_heatmaps = batch['right_heatmap'].unsqueeze(1).to(gpu)

# Text embeddings: drop axis 1, cast to float32, move to the device.
left_text_ft = batch['left_text_emb'].squeeze(1).float().to(gpu)
right_text_ft = batch['right_text_emb'].squeeze(1).float().to(gpu)

# Labels: float column vector so it lines up with the model's one-value-per-sample output.
label_tensor = labels.float().unsqueeze(1).to(gpu)

# Running totals for the epoch: loss, confusion-matrix counts, sample counter.
epoch_loss = 0.0
epoch_TP = epoch_FP = epoch_FN = epoch_TN = 0
processed = 0

In [6]:
# One forward (and, when training, backward) pass over the prepared batch.
optimizer.zero_grad()
# Gradients are only tracked in the training phase.
with torch.set_grad_enabled(phase=='train'):
    # NOTE: the model's head ends in Sigmoid, so `logits` holds probabilities.
    logits = model(left_heatmaps, right_heatmaps, left_text_ft, right_text_ft)
    loss = criterion(logits, label_tensor)
    # criterion returns the batch mean; weight it by the batch size so the
    # running total can later be divided by the number of samples.
    epoch_loss += loss.item() * len(image_paths)
    # Only back-propagate and update weights when training. With gradients
    # disabled, backward() would raise because the loss has no graph.
    if phase == 'train':
        loss.backward()
        optimizer.step()

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [7]:
# Inspect the model output for this batch. These are sigmoid outputs in (0, 1),
# one row per sample. The whole tensor is printed; slice it (e.g. logits[:10])
# for a shorter display.
logits

tensor([[0.5050],
        [0.7447],
        [0.3901],
        [0.6819],
        [0.7844],
        [0.8535],
        [0.5290],
        [0.5518],
        [0.6048],
        [0.4380],
        [0.6518],
        [0.7480],
        [0.5290],
        [0.6048],
        [0.6518],
        [0.7438],
        [0.5165],
        [0.7763],
        [0.5064],
        [0.5403],
        [0.2950],
        [0.4063],
        [0.5064],
        [0.2950],
        [0.7507],
        [0.3790],
        [0.7400],
        [0.7680],
        [0.7864],
        [0.7680],
        [0.5914],
        [0.8248],
        [0.7042],
        [0.6505],
        [0.8614],
        [0.8878],
        [0.9425],
        [0.7544],
        [0.7235],
        [0.7025],
        [0.9151],
        [0.9374],
        [0.8788],
        [0.8541],
        [0.8936],
        [0.8551],
        [0.8268],
        [0.8605],
        [0.9533],
        [0.7799],
        [0.4781],
        [0.6868],
        [0.7415],
        [0.4876],
        [0.3816],
        [0

In [8]:
# Hard predictions: 1.0 where the model output exceeds 0.5, else 0.0.
preds = (logits > 0.5).float()

# Boolean masks over the flattened (axis 1 squeezed) predictions and labels.
_pred_pos = preds.squeeze(1) == 1
_pred_neg = preds.squeeze(1) == 0
_gold_pos = label_tensor.squeeze(1) == 1
_gold_neg = label_tensor.squeeze(1) == 0

# Accumulate confusion-matrix counts (as Python floats) into the epoch totals.
epoch_TP += (_pred_pos & _gold_pos).float().sum().item()
epoch_FP += (_pred_pos & _gold_neg).float().sum().item()
epoch_FN += (_pred_neg & _gold_pos).float().sum().item()
epoch_TN += (_pred_neg & _gold_neg).float().sum().item()

# Precision: TP / (TP + FP); 0 when nothing was predicted positive.
if (epoch_TP + epoch_FP) > 0:
    epoch_prec = epoch_TP / (epoch_TP + epoch_FP)
else:
    epoch_prec = 0

# Recall: TP / (TP + FN); 0 when there are no positive labels.
if (epoch_TP + epoch_FN) > 0:
    epoch_rec = epoch_TP / (epoch_TP + epoch_FN)
else:
    epoch_rec = 0

# F1: harmonic mean of precision and recall; 0 when both are 0.
if (epoch_prec + epoch_rec) > 0:
    epoch_f1 = 2 * epoch_prec * epoch_rec / (epoch_prec + epoch_rec)
else:
    epoch_f1 = 0

In [9]:
# Precision, TP / (TP + FP), over the counts accumulated so far.
epoch_prec

0.1686746987951807

In [10]:
# Recall, TP / (TP + FN), over the counts accumulated so far.
epoch_rec

0.9333333333333333

In [11]:
# F1 score (harmonic mean of the precision and recall computed above).
epoch_f1

0.2857142857142857

In [12]:
# Number of correct predictions in the batch. Use the tensor's own reduction
# rather than Python's builtin sum(), which iterates the tensor row by row and
# issues one device op per sample; .item() returns a plain Python number.
(preds == label_tensor).sum().item()

tensor([30], device='cuda:4')