In [None]:
!pip install gdown
!pip install pytube
import os
# from os.path import exists
# if not exists('unbalanced_train_segments.csv'):
  # !curl http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/unbalanced_train_segments.csv -o unbalanced_train_segments.csv

!gdown 1Syt6DFnTMk0PolN76MtabSrGYj4_FVoL
!unzip -q code.zip
!wget -O data.zip https://www.dropbox.com/s/oipv2a03ro1l3f0/data_small.zip?dl=0
!unzip -q data.zip
os.rename('data_small', 'data')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import notebook

import torchvision
from torchvision import transforms

import matplotlib.pyplot as plt
import numpy as np
import cv2

from dataset import AudioSet
import visualize as v
from config import get_config
# Load the project configuration used by the dataset cells below.
cfg = get_config('config.ini')

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
# Select the compute device: first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Using the GPU!")
else:
    device = torch.device("cpu")
    print("WARNING: Could not find GPU! Using CPU only. If you want to enable GPU, please to go Edit > Notebook Settings > Hardware Accelerator and select GPU.")

In [None]:
# Build the three dataset splits with identical settings; only the split name differs.
# `size` is forwarded unchanged to AudioSet (None here).
size = None
trainAS, valAS, testAS = [
    AudioSet(split_name, cfg['SAVE_DIR'],
             VAL_RATIO=cfg['VAL_RATIO'],
             TEST_RATIO=cfg['TEST_RATIO'],
             FT_RATIO=cfg['FT_RATIO'],
             size=size)
    for split_name in ('train', 'val', 'test')
]

In [None]:
# Preview five random samples from each split.
# Training samples are shown both as a pair and as a single clip.
for i in range(5):
  idx = np.random.randint(0, len(trainAS))
  v.viz_pair(trainAS, idx, t='train pair w augmentation')
  v.viz_clip(trainAS, idx, t='train ')

# Validation and test samples are shown as single clips only.
for split_set, title_prefix in ((valAS, 'val '), (testAS, 'test ')):
  for i in range(5):
    idx = np.random.randint(0, len(split_set))
    v.viz_clip(split_set, idx, t=title_prefix)

plt.show()

In [None]:
class VisionConvNet(nn.Module):
  """Convolutional trunk of the image branch.

  Eight `3x3 conv -> BatchNorm -> ReLU` stages with 64, 64, 128, 128, 256, 256,
  512 and 512 output channels, and a 2x2 max-pool after the 2nd, 4th and 6th
  stage. The first convolution has stride 2, so together with the three pools
  the spatial resolution is reduced by a factor of 16.
  """

  def __init__(self):
    super(VisionConvNet, self).__init__()

    def stage(c_in, c_out, stride=1):
      # 3x3 convolution (padding 1) followed by batch norm and ReLU
      return [nn.Conv2d(c_in, c_out, 3, stride, 1), nn.BatchNorm2d(c_out), nn.ReLU()]

    # The layer order is unchanged, so the state_dict keys (net.0 ... net.26)
    # stay compatible with previously saved checkpoints.
    self.net = nn.Sequential(
        *stage(3, 64, stride=2),   # conv1_1
        *stage(64, 64),            # conv1_2
        nn.MaxPool2d(2),           # pool1
        *stage(64, 128),           # conv2_1
        *stage(128, 128),          # conv2_2
        nn.MaxPool2d(2),           # pool2
        *stage(128, 256),          # conv3_1
        *stage(256, 256),          # conv3_2
        nn.MaxPool2d(2),           # pool3
        *stage(256, 512),          # conv4_1
        *stage(512, 512),          # conv4_2
    )

  def norm_init(self, mean, std):
    """Re-initialise every Conv2d: weights ~ N(mean, std), biases set to zero.

    Walks `self.modules()` (recursive) so the convolutions nested inside
    `self.net` are reached. Iterating `self._modules` would only yield the
    child *names* (strings) and therefore never match a layer type.
    """
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:
          nn.init.zeros_(m.bias)

  def forward(self, x):
    """Return the feature map produced by the convolutional stack for input `x`."""
    return self.net(x)

class AudioConvNet(nn.Module):
  """Convolutional trunk of the audio branch.

  Same layout as the image trunk but with a single input channel: eight
  `3x3 conv -> BatchNorm -> ReLU` stages (64, 64, 128, 128, 256, 256, 512, 512
  channels) with a 2x2 max-pool after the 2nd, 4th and 6th stage. The first
  convolution has stride 2.
  """

  def __init__(self):
    super(AudioConvNet, self).__init__()

    def stage(c_in, c_out, stride=1):
      # 3x3 convolution (padding 1) followed by batch norm and ReLU
      return [nn.Conv2d(c_in, c_out, 3, stride, 1), nn.BatchNorm2d(c_out), nn.ReLU()]

    # The layer order is unchanged, so the state_dict keys (net.0 ... net.26)
    # stay compatible with previously saved checkpoints.
    self.net = nn.Sequential(
        *stage(1, 64, stride=2),   # conv1_1
        *stage(64, 64),            # conv1_2
        nn.MaxPool2d(2),           # pool1
        *stage(64, 128),           # conv2_1
        *stage(128, 128),          # conv2_2
        nn.MaxPool2d(2),           # pool2
        *stage(128, 256),          # conv3_1
        *stage(256, 256),          # conv3_2
        nn.MaxPool2d(2),           # pool3
        *stage(256, 512),          # conv4_1
        *stage(512, 512),          # conv4_2
    )

  def norm_init(self, mean, std):
    """Re-initialise every Conv2d: weights ~ N(mean, std), biases set to zero.

    Walks `self.modules()` (recursive) so the convolutions nested inside
    `self.net` are reached. Iterating `self._modules` would only yield the
    child *names* (strings) and therefore never match a layer type.
    """
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:
          nn.init.zeros_(m.bias)

  def forward(self, x):
    """Return the feature map produced by the convolutional stack for input `x`."""
    return self.net(x)

class AVOLNet(nn.Module):
  """Two-branch audio-visual network producing a correspondence score and a localisation map.

  The image branch yields one 128-d embedding per spatial location, the audio
  branch yields a single 128-d embedding per clip. Their dot product at every
  location is passed through a 1x1 conv and a sigmoid to give the localisation
  map; the maximum of that map is the correspondence score.
  """

  def __init__(self):
    super(AVOLNet, self).__init__()

    # image network
    self.visionNet = VisionConvNet()
    self.conv5 = nn.Conv2d(512, 128, 1)
    self.conv6 = nn.Conv2d(128, 128, 1)

    # audio network
    self.audioNet = AudioConvNet()
    self.pool4 = nn.AdaptiveMaxPool2d(1)
    self.fc1 = nn.Linear(512,128)
    self.fc2 = nn.Linear(128,128)

    # fusion network
    self.conv7 = nn.Conv2d(1,1,1)
    self.sigmoid = nn.Sigmoid()
    self.maxpool = nn.AdaptiveMaxPool2d(1)

  def norm_init(self, mean, std):
    """Re-initialise every Conv2d/Linear layer: weights ~ N(mean, std), biases zero.

    `self.modules()` is recursive, so this covers the layers owned directly by
    this module as well as every convolution inside both sub-networks.
    Iterating `self._modules` would only yield child *names* (strings) and
    never match a layer type, leaving all weights untouched.
    """
    for m in self.modules():
      if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:
          nn.init.zeros_(m.bias)

  def forward(self, image, audio):
    """Compute the correspondence score and localisation map for a batch.

    Returns:
      corresponds: max of the localisation map per sample, with all singleton
        dimensions squeezed (shape (B,), or a 0-d tensor when B == 1).
      localisation: sigmoid map of shape (B, h, w), where (h, w) is the spatial
        size of the image feature map.
    """
    # process image: (B, C, h, w) -> (B, h*w, C), one embedding per location
    img = self.conv6(self.conv5(self.visionNet(image)))
    batch, channels, height, width = img.shape
    img = img.reshape(batch, channels, height * width).permute(0, 2, 1)

    # process audio: global max-pool to one vector per clip, then two FC layers
    aud = self.pool4(self.audioNet(audio)).flatten(1)
    aud = self.fc2(F.relu(self.fc1(aud)))
    aud = aud.unsqueeze(2)  # (B, C, 1)

    # fuse embeddings: dot product of the audio vector with every image location.
    # The map is reshaped to the actual feature-map size instead of a fixed 8x8.
    dot_product = (img @ aud).reshape(batch, 1, height, width)

    localisation = self.sigmoid(self.conv7(dot_product))
    corresponds = self.maxpool(localisation).squeeze()

    return corresponds, localisation.squeeze(1)

In [None]:
def _save_localisation_figure(model, batch, device, out_path):
  """Save a two-row figure: images overlaid with the localisation map (top) and audio inputs (bottom)."""
  tp = {'left':False, 'right':False, 'labelleft':False, 'labelbottom':False, 'bottom':False}
  match_str = ['fake pair','real pair']

  img, aud, lab, indx = batch
  img = img.to(device, dtype=torch.float)
  aud = aud.to(device, dtype=torch.float)
  with torch.no_grad():
    _, local = model(img, aud)
    # Broadcast the single-channel map over 3 channels and enlarge it to the
    # image size (nearest neighbour) so it can be added onto the image.
    loc = local.unsqueeze(1).expand(-1, 3, -1, -1)
    loc = F.interpolate(loc, size=img.shape[-2:], mode='nearest')
    overlay = (img + loc).cpu().numpy().transpose(0, 2, 3, 1)
  L = local.cpu().numpy()
  # Drop only the channel axis; a blanket squeeze would also drop a batch axis of size 1.
  spectrograms = np.squeeze(aud.cpu().numpy(), axis=1)

  N = len(img)
  # squeeze=False keeps `ax` two-dimensional even when N == 1
  fig, ax = plt.subplots(nrows=2, ncols=N, figsize=(2*N, 4), squeeze=False)
  for i in range(N):
    print('\r{}/{} visualisations'.format(i,N-1), end='')
    min_row, min_col = np.unravel_index(np.argmin(L[i]), L[i].shape)
    max_row, max_col = np.unravel_index(np.argmax(L[i]), L[i].shape)
    # Rescale to [0, 1] for display, guarding against a constant image.
    frame = overlay[i] - np.min(overlay[i])
    peak = np.max(frame)
    if peak > 0:
      frame = frame / peak
    ax[0][i].imshow(frame)
    ax[1][i].imshow(spectrograms[i])
    ax[0][i].set_title('{},{} {:.3f}\n{},{} {:.3f}'\
                      .format(min_row, min_col, np.min(L[i]),\
                              max_row, max_col, np.max(L[i])), y=0, pad=7)
    ax[1][i].set_title('{} {:.0f} {:.0f} {:.0f} {:.0f}'.format(match_str[int(lab[i].item())], *list(indx[i])), y=0, pad=7)
    ax[0][i].tick_params('both', **tp)
    ax[1][i].tick_params('both', **tp)

  fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
  plt.savefig(out_path, dpi=300)
  plt.close(fig)


def _save_roi_figure(model, batch, device, out_path):
  """Save a grid showing, for each localisation cell, the input-image gradient of that cell (summed over the batch)."""
  tp = {'left':False, 'right':False, 'labelleft':False, 'labelbottom':False, 'bottom':False}

  img, aud, _, _ = batch
  img = img.to(device, dtype=torch.float).requires_grad_(True)
  aud = aud.to(device, dtype=torch.float)
  _, local = model(img, aud)
  local = torch.sum(local, dim=0)
  rows, cols = tuple(local.shape)
  # figsize is (width, height): columns drive the width, rows the height
  fig, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(cols*2, rows*2), squeeze=False)
  for i in range(rows):
    for j in range(cols):
      print('\r{},{} RoIs'.format(i,j), end='')
      # autograd.grad returns the gradient w.r.t. the image only, so nothing is
      # accumulated into the model parameters' .grad buffers.
      (image_grad,) = torch.autograd.grad(local[i,j], img, retain_graph=True)
      grads = np.transpose(torch.sum(image_grad, dim=0).cpu().numpy(), axes=[1,2,0])
      m = np.min(grads)
      M = np.max(grads)
      grads = grads - m
      span = np.max(grads)
      if span > 0:  # avoid 0/0 when the gradient is constant
        grads = grads / span
      ax[i][j].imshow(grads)
      ax[i][j].set_title('{:.3f} to {:.3f}'.format(m, M), y=0, pad=7)
      ax[i][j].tick_params('both', **tp)

  fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
  plt.savefig(out_path, dpi=300)
  plt.close(fig)


def train_AVOL(model, viz_loader, train_loader, val_loader, learning_rate = 5e-4, weight_decay = 1e-5, epochs = 200, device = "cpu", start_epoch = 0):
  """Train `model` with binary cross-entropy and write per-epoch artefacts to `save/`.

  Every epoch this trains on `train_loader`, saves two diagnostic figures built
  from the first two batches of `viz_loader`, evaluates on `val_loader`, appends
  a line to `save/training_log.txt` and saves a checkpoint
  `save/AVOLNet{epoch}.pt`.

  Args:
    model: network called as `model(images, audios)` returning
      `(correspondence, localisation)`.
    viz_loader: loader used only for the figures; must yield at least two batches.
    train_loader, val_loader: loaders yielding `(images, audios, labels, index)`.
    learning_rate: base learning rate; it is pre-decayed by 0.94 for every 16
      epochs already covered by `start_epoch`.
    weight_decay: Adam weight decay.
    epochs: epoch index at which training stops (exclusive).
    device: device the batches are moved to.
    start_epoch: first epoch index, for resuming a run.

  Returns:
    `(train_loss_history, train_acc_history, eval_loss_history, eval_acc_history)`,
    one entry per epoch. Loss entries are the *sum* of the batch losses of the
    epoch; accuracy entries are fractions of correctly classified samples.
    Entries are 0-d tensors on `device` (convert with `float(x)` before plotting).

  NOTE(review): the StepLR scheduler is created fresh, so when resuming its
  16-epoch counter starts at `start_epoch`; decay boundaries only line up with
  an uninterrupted run when `start_epoch` is a multiple of 16.
  """
  import os

  # All artefacts go here; create it up front so savefig/open/torch.save cannot fail.
  save_dir = 'save'
  os.makedirs(save_dir, exist_ok=True)

  # reduce learning rate if starting epoch is not zero
  learning_rate = 0.94**(start_epoch // 16)* learning_rate

  # initialise optimizer and lr scheduler
  optimizer = optim.Adam(model.parameters(),lr=learning_rate,weight_decay=weight_decay)
  lr_scheduler = optim.lr_scheduler.StepLR(optimizer,16,0.94) # reduces learning rate by 6% every 16 epochs

  # initialise loss criterion
  criterion = nn.BCELoss()

  # Keep track of loss for plotting.
  train_loss_history = []
  train_acc_history = []
  eval_loss_history = []
  eval_acc_history = []

  for epoch in notebook.tqdm(range(start_epoch, epochs)):
    train_loss_epoch = 0.0
    train_acc_epoch = 0.0

    # train model
    model.train()
    for batch_num, (images, audios, labels, indx) in enumerate(train_loader):
      print('\rbatch: ' + str(batch_num), end='')

      # send to device
      images, audios, labels = images.to(device), audios.to(device, dtype=torch.float), labels.to(device).squeeze()

      optimizer.zero_grad()
      output, local = model(images, audios)

      loss = criterion(output, labels)
      loss.backward()
      optimizer.step()

      with torch.no_grad():
        train_predict = torch.round(output)

      train_loss_epoch += loss.detach()
      train_acc_epoch += torch.sum(train_predict == labels).detach()

    # Diagnostic figures. Eval mode matters here: in train mode BatchNorm
    # normalises with batch statistics, which makes every output depend on every
    # sample/pixel of the batch (polluting the gradient maps) and also shifts
    # the running statistics with the visualisation batches.
    model.eval()
    viz_batches = iter(viz_loader)
    _save_localisation_figure(model, next(viz_batches), device,
                              os.path.join(save_dir, 'epoch {} localisation.jpg'.format(epoch)))
    _save_roi_figure(model, next(viz_batches), device,
                     os.path.join(save_dir, 'epoch {} RoI.jpg'.format(epoch)))

    train_loss_history.append(train_loss_epoch)
    train_acc_epoch /= len(train_loader.dataset)
    train_acc_history.append(train_acc_epoch)

    # evaluate model
    eval_loss_epoch = 0.0
    eval_acc_epoch = 0.0

    model.eval()
    with torch.no_grad():
      for images, audios, labels, _ in val_loader:
        # send to device
        images, audios, labels = images.to(device), audios.to(device, dtype=torch.float), labels.to(device).squeeze()

        output, _ = model(images, audios)
        loss = criterion(output, labels)
        eval_predict = torch.round(output)

        eval_loss_epoch += loss.detach()
        eval_acc_epoch += torch.sum(eval_predict == labels).detach()

    eval_loss_history.append(eval_loss_epoch)
    eval_acc_epoch /= len(val_loader.dataset)
    eval_acc_history.append(eval_acc_epoch)

    # step scheduler at end of epoch
    lr_scheduler.step()

    # report, log and checkpoint every epoch
    summary = 'Epoch {}, training loss {:.3f}, train accuracy {}, validation loss {:.3f}, validation accuracy {}'.format(epoch, train_loss_epoch, train_acc_epoch, eval_loss_epoch, eval_acc_epoch)
    print('\r', end='')
    print(summary)
    # `with` guarantees the log is flushed and closed even if saving the checkpoint fails
    with open(os.path.join(save_dir, 'training_log.txt'), 'a') as fileOut:
      print(summary, file=fileOut)
    torch.save(model.state_dict(), os.path.join(save_dir, "AVOLNet{}.pt".format(epoch)))

  return train_loss_history, train_acc_history, eval_loss_history, eval_acc_history

def test_AVOL(model, test_loader, show_location = False):
  test_acc_epoch = 0

  model.eval()
  for images, audios, labels in test_loader:
    # send to device 
    images, audio, labels = images.to(device), audio.to(device), labels.to(device)

    with torch.no_grad():
      output, location = model(images, audio)
      test_predict = torch.round(output)
      test_acc_epoch += torch.sum(test_predict == labels)
      if (show_location):
        # 1 for true pair, 0 for false pair
        imgs = images[labels == 1 and test_predict == 1]
        locs = location[labels == 1 and test_predict == 1]
        for img, loc in zip(imgs,locs):
          output_img = visualize_localization(img,loc)
          plt.imshow(output_img)
  
  test_acc_epoch /= len(test_loader.dataset)

  print('Test accuracy {}'.format(test_acc_epoch))


In [None]:
def visualize_localization(image, loc):
  """Blend a localisation map onto an image as a 'hot' colour-map heat-map.

  Args:
    image: (3, H, W) float tensor. It is de-normalised with the per-channel
      mean/std constants below (presumably the normalisation applied by the
      dataset; confirm against the data pipeline).
    loc: 2-D tensor with values in [0, 1] (the model's sigmoid output).

  Returns:
    (H, W, 3) uint8 RGB array: 50/50 blend of the image and the resized heat-map.
  """
  # inverse transformation: first multiply by std, then add the mean
  inverse_norm = transforms.Compose([
            transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1 / 0.229, 1 / 0.224, 1 / 0.225]),
            transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1.0, 1.0, 1.0])])

  img = inverse_norm(image.detach()).cpu()

  # convert from RGB to BGR (OpenCV channel order) by reversing the channel axis
  img = img.flip(0)

  # convert from (C x H x W) to (H x W x C)
  img = img.numpy().transpose((1, 2, 0))
  # Clip negatives first: casting a negative float to uint8 wraps around.
  img = np.clip(img, 0.0, None)
  peak = img.max()
  if peak > 0:  # guard against an all-black image
    img = img / peak
  img = np.ascontiguousarray((img * 255).astype(np.uint8))

  # convert location map into image
  loc = loc.detach().cpu()
  loc = (loc * 255).numpy().astype(np.uint8)
  heatmap = cv2.applyColorMap(loc, cv2.COLORMAP_HOT)

  # resize heatmap to be same with img; cv2.resize expects dsize as (width, height)
  heatmap_upscaled = cv2.resize(heatmap, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_AREA)

  # combine heatmap with image
  dst = cv2.addWeighted(img, 0.5, heatmap_upscaled, 0.5, 0)

  return cv2.cvtColor(dst, cv2.COLOR_BGR2RGB)

In [None]:
# train network

train = trainAS
val_set = valAS  # not named `eval`, which would shadow the Python builtin
train_loader = DataLoader(train,batch_size=64,shuffle=True,num_workers=2,pin_memory=True)
# evaluation does not depend on sample order, so no shuffling is needed
val_loader = DataLoader(val_set,batch_size=64,shuffle=False,num_workers=2,pin_memory=True)
# train_AVOL draws two batches per epoch from this loader for its diagnostic
# figures; an unshuffled loader shows the same samples every epoch.
viz_loader = DataLoader(val_set,batch_size=8,shuffle=False,num_workers=2,pin_memory=True)

AVOL = AVOLNet().to(device)

# load existing weights (OPTIONAL)
weights_path = None
if weights_path is not None:
  AVOL.load_state_dict(torch.load(weights_path, map_location="cpu"))

# train_AVOL's signature is (model, viz_loader, train_loader, val_loader, ...);
# all three loaders must be supplied.
tr_loss, tr_acc, val_loss, val_acc = train_AVOL(AVOL, viz_loader, train_loader, val_loader, learning_rate = 5e-4, epochs=100,device=device, start_epoch=0)

# visualize losses (one value per epoch)
plt.title("Training loss history")
plt.xlabel("Epoch")
plt.ylabel("Loss")
# the history is a Python list of scalars; convert each entry to a float for matplotlib
plt.plot([float(loss) for loss in tr_loss])
plt.show()

In [None]:
# test network: build a fresh model, load a checkpoint and report test accuracy

test = testAS

test_loader = DataLoader(test,batch_size=50,shuffle=True,num_workers=2,pin_memory=True)

AVOL = AVOLNet().to(device)

# NOTE(review): the training loop in this notebook writes checkpoints as
# "save/AVOLNet{epoch}.pt"; this path has no "save/" prefix, so the file
# presumably has to be provided separately. Confirm before running.
weights_path = "AVOLNet100.pt"
# if weights_path is not None:
# The state dict is loaded on CPU first; load_state_dict copies it into the model's parameters.
AVOL.load_state_dict(torch.load(weights_path, map_location="cpu"))

test_AVOL(AVOL,test_loader, show_location=True)

In [None]:
# save network weights
# Writes only the parameters/buffers (state_dict) of the current `AVOL` model to the working directory.
torch.save(AVOL.state_dict(), "AVOLNet.pt")

In [None]:
# load network weights
# Rebuilds the model from whichever `AVOLNet` class definition was executed last
# (the notebook defines several variants under the same name), then restores the saved parameters.
AVOL = AVOLNet().to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
AVOL.load_state_dict(torch.load("AVOLNet.pt", map_location="cpu"))

In [None]:
#visualize receptive field
# For every cell of the localisation map, plot the gradient of that cell with
# respect to the input images (summed over the batch).
# The dataset is referenced as `trainAS` throughout so this cell does not depend
# on the `train` alias created in the training cell.
AOEloader = DataLoader(trainAS,batch_size=8,shuffle=False,num_workers=2,pin_memory=True,sampler=np.random.permutation(len(trainAS)).tolist())
it = iter(AOEloader)
img,aud,lab,indx = next(it)
img=img.to(device, dtype=torch.float).requires_grad_(True)
aud=aud.to(device, dtype=torch.float)
# Eval mode is required: in train mode BatchNorm normalises with statistics of
# the whole batch, so every output would depend on every pixel of every sample
# and the gradient maps would no longer show the receptive field.
AVOL.eval()
output, local = AVOL(img,aud)
local = torch.sum(local, dim=0)
s = tuple(local.shape)
print(s)
# figsize is (width, height); squeeze=False keeps `ax` 2-D for 1xN / Nx1 maps
fig,ax = plt.subplots(nrows=s[0], ncols=s[1], figsize=(s[1]*2, s[0]*2), squeeze=False)
tp = {'left':False, 'right':False, 'labelleft':False, 'labelbottom':False, 'bottom':False}
for i in range(s[0]):
  for j in range(s[1]):
    print('\r{},{}'.format(i,j), end='')
    # autograd.grad returns the image gradient directly and leaves the
    # .grad buffers of the model parameters untouched
    (image_grad,) = torch.autograd.grad(local[i,j], img, retain_graph=True)
    grads = np.transpose(torch.sum(image_grad, dim=0).cpu().numpy(), axes=[1,2,0])
    m = np.min(grads)
    M = np.max(grads)
    grads = grads - m
    span = np.max(grads)
    if span > 0:  # avoid 0/0 when the gradient is constant
      grads = grads / span
    ax[i][j].imshow(grads)
    ax[i][j].set_title('{:.3f} to {:.3f}'.format(m, M), y=0, pad=7)
    ax[i][j].tick_params('both', **tp)

fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
plt.show()

In [None]:
#debugging visualize receptive field.
class newAVOLNet(nn.Module):
  """Copy of the audio-visual network used for debugging the receptive-field plot.

  The image branch yields one 128-d embedding per spatial location, the audio
  branch yields a single 128-d embedding per clip. Their dot product at every
  location is passed through a 1x1 conv and a sigmoid to give the localisation
  map; the maximum of that map is the correspondence score.
  """

  def __init__(self):
    super(newAVOLNet, self).__init__()

    # image network
    self.visionNet = VisionConvNet()
    self.conv5 = nn.Conv2d(512, 128, 1)
    self.conv6 = nn.Conv2d(128, 128, 1)

    # audio network
    self.audioNet = AudioConvNet()
    self.pool4 = nn.AdaptiveMaxPool2d(1)
    self.fc1 = nn.Linear(512,128)
    self.fc2 = nn.Linear(128,128)

    # fusion network
    self.conv7 = nn.Conv2d(1,1,1)
    self.sigmoid = nn.Sigmoid()
    self.maxpool = nn.AdaptiveMaxPool2d(1)

  def norm_init(self, mean, std):
    """Re-initialise every Conv2d/Linear layer: weights ~ N(mean, std), biases zero.

    `self.modules()` is recursive, so this covers the layers owned directly by
    this module as well as every convolution inside both sub-networks.
    Iterating `self._modules` would only yield child *names* (strings) and
    never match a layer type, leaving all weights untouched.
    """
    for m in self.modules():
      if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:
          nn.init.zeros_(m.bias)

  def forward(self, image, audio):
    """Compute the correspondence score and localisation map for a batch.

    Returns:
      corresponds: max of the localisation map per sample, with all singleton
        dimensions squeezed (shape (B,), or a 0-d tensor when B == 1).
      localisation: sigmoid map of shape (B, h, w), where (h, w) is the spatial
        size of the image feature map.
    """
    # process image: (B, C, h, w) -> (B, h*w, C), one embedding per location
    img = self.conv6(self.conv5(self.visionNet(image)))
    batch, channels, height, width = img.shape
    img = img.reshape(batch, channels, height * width).permute(0, 2, 1)

    # process audio: global max-pool to one vector per clip, then two FC layers
    aud = self.pool4(self.audioNet(audio)).flatten(1)
    aud = self.fc2(F.relu(self.fc1(aud)))
    aud = aud.unsqueeze(2)  # (B, C, 1)

    # fuse embeddings: dot product of the audio vector with every image location.
    # The map is reshaped to the actual feature-map size instead of a fixed 8x8.
    dot_product = (img @ aud).reshape(batch, 1, height, width)

    localisation = self.sigmoid(self.conv7(dot_product))
    corresponds = self.maxpool(localisation).squeeze()

    return corresponds, localisation.squeeze(1)

# Receptive-field plot for a freshly initialised network: for every cell of the
# localisation map, plot the gradient of that cell with respect to the input
# images (summed over the batch).
AOEloader = DataLoader(trainAS,batch_size=8,shuffle=False,num_workers=2,pin_memory=True,sampler=np.random.permutation(len(trainAS)).tolist())
it = iter(AOEloader)
img,aud,lab,indx = next(it)
img=img.to(device, dtype=torch.float).requires_grad_(True)
aud=aud.to(device, dtype=torch.float)
newAVOL = newAVOLNet().to(device)
newAVOL.norm_init(0, 1e-3)
# Eval mode is required: in train mode BatchNorm normalises with statistics of
# the whole batch, so every output would depend on every pixel of every sample
# and the gradient maps would no longer show the receptive field.
newAVOL.eval()
output, local = newAVOL(img,aud)
local = torch.sum(local, dim=0)
s = tuple(local.shape)
print(s)
# figsize is (width, height); squeeze=False keeps `ax` 2-D for 1xN / Nx1 maps
fig,ax = plt.subplots(nrows=s[0], ncols=s[1], figsize=(s[1]*2, s[0]*2), squeeze=False)
tp = {'left':False, 'right':False, 'labelleft':False, 'labelbottom':False, 'bottom':False}
for i in range(s[0]):
  for j in range(s[1]):
    print('\r{},{}'.format(i,j), end='')
    # autograd.grad returns the image gradient directly and leaves the
    # .grad buffers of the model parameters untouched
    (image_grad,) = torch.autograd.grad(local[i,j], img, retain_graph=True)
    grads = np.transpose(torch.sum(image_grad, dim=0).cpu().numpy(), axes=[1,2,0])
    m = np.min(grads)
    M = np.max(grads)
    grads = grads - m
    span = np.max(grads)
    if span > 0:  # avoid 0/0 when the gradient is constant
      grads = grads / span
    ax[i][j].imshow(grads)
    ax[i][j].set_title('{:.3f} to {:.3f}'.format(m, M), y=0, pad=7)
    ax[i][j].tick_params('both', **tp)

fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
plt.show()

In [None]:
# reduced channel network

class VisionConvNet(nn.Module):
  """Reduced-channel convolutional trunk of the image branch.

  Eight `3x3 conv -> BatchNorm -> ReLU` stages with 32, 32, 64, 64, 128, 128,
  256 and 256 output channels, and a 2x2 max-pool after the 2nd, 4th and 6th
  stage. The first convolution has stride 2.

  NOTE: running this cell rebinds the name `VisionConvNet`, replacing the
  wider definition from earlier in the notebook.
  """

  def __init__(self):
    super(VisionConvNet, self).__init__()

    def stage(c_in, c_out, stride=1):
      # 3x3 convolution (padding 1) followed by batch norm and ReLU
      return [nn.Conv2d(c_in, c_out, 3, stride, 1), nn.BatchNorm2d(c_out), nn.ReLU()]

    # The layer order is unchanged, so the state_dict keys (net.0 ... net.26)
    # stay compatible with previously saved checkpoints.
    self.net = nn.Sequential(
        *stage(3, 32, stride=2),   # conv1_1
        *stage(32, 32),            # conv1_2
        nn.MaxPool2d(2),           # pool1
        *stage(32, 64),            # conv2_1
        *stage(64, 64),            # conv2_2
        nn.MaxPool2d(2),           # pool2
        *stage(64, 128),           # conv3_1
        *stage(128, 128),          # conv3_2
        nn.MaxPool2d(2),           # pool3
        *stage(128, 256),          # conv4_1
        *stage(256, 256),          # conv4_2
    )

  def norm_init(self, mean, std):
    """Re-initialise every Conv2d: weights ~ N(mean, std), biases set to zero.

    Walks `self.modules()` (recursive) so the convolutions nested inside
    `self.net` are reached. Iterating `self._modules` would only yield the
    child *names* (strings) and therefore never match a layer type.
    """
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:
          nn.init.zeros_(m.bias)

  def forward(self, x):
    """Return the feature map produced by the convolutional stack for input `x`."""
    return self.net(x)

class AudioConvNet(nn.Module):
  """Reduced-channel convolutional trunk of the audio branch.

  Single input channel; eight `3x3 conv -> BatchNorm -> ReLU` stages with 32,
  32, 64, 64, 128, 128, 256 and 256 output channels, and a 2x2 max-pool after
  the 2nd, 4th and 6th stage. The first convolution has stride 2.

  NOTE: running this cell rebinds the name `AudioConvNet`, replacing the
  wider definition from earlier in the notebook.
  """

  def __init__(self):
    super(AudioConvNet, self).__init__()

    def stage(c_in, c_out, stride=1):
      # 3x3 convolution (padding 1) followed by batch norm and ReLU
      return [nn.Conv2d(c_in, c_out, 3, stride, 1), nn.BatchNorm2d(c_out), nn.ReLU()]

    # The layer order is unchanged, so the state_dict keys (net.0 ... net.26)
    # stay compatible with previously saved checkpoints.
    self.net = nn.Sequential(
        *stage(1, 32, stride=2),   # conv1_1
        *stage(32, 32),            # conv1_2
        nn.MaxPool2d(2),           # pool1
        *stage(32, 64),            # conv2_1
        *stage(64, 64),            # conv2_2
        nn.MaxPool2d(2),           # pool2
        *stage(64, 128),           # conv3_1
        *stage(128, 128),          # conv3_2
        nn.MaxPool2d(2),           # pool3
        *stage(128, 256),          # conv4_1
        *stage(256, 256),          # conv4_2
    )

  def norm_init(self, mean, std):
    """Re-initialise every Conv2d: weights ~ N(mean, std), biases set to zero.

    Walks `self.modules()` (recursive) so the convolutions nested inside
    `self.net` are reached. Iterating `self._modules` would only yield the
    child *names* (strings) and therefore never match a layer type.
    """
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:
          nn.init.zeros_(m.bias)

  def forward(self, x):
    """Return the feature map produced by the convolutional stack for input `x`."""
    return self.net(x)

class AVOLNet(nn.Module):
  """Reduced-channel audio-visual network (64-d embeddings).

  The image branch yields one 64-d embedding per spatial location, the audio
  branch yields a single 64-d embedding per clip. Their dot product at every
  location is passed through a 1x1 conv and a sigmoid to give the localisation
  map; the maximum of that map is the correspondence score.

  NOTE: running this cell rebinds the name `AVOLNet`, replacing the wider
  definition from earlier in the notebook.
  """

  def __init__(self):
    super(AVOLNet, self).__init__()

    # image network
    self.visionNet = VisionConvNet()
    self.conv5 = nn.Conv2d(256, 64, 1)
    self.conv6 = nn.Conv2d(64, 64, 1)

    # audio network
    self.audioNet = AudioConvNet()
    self.pool4 = nn.AdaptiveMaxPool2d(1)
    self.fc1 = nn.Linear(256,64)
    self.fc2 = nn.Linear(64,64)

    # fusion network
    self.conv7 = nn.Conv2d(1,1,1)
    self.sigmoid = nn.Sigmoid()
    self.maxpool = nn.AdaptiveMaxPool2d(1)

  def norm_init(self, mean, std):
    """Re-initialise every Conv2d/Linear layer: weights ~ N(mean, std), biases zero.

    `self.modules()` is recursive, so this covers the layers owned directly by
    this module as well as every convolution inside both sub-networks.
    Iterating `self._modules` would only yield child *names* (strings) and
    never match a layer type, leaving all weights untouched.
    """
    for m in self.modules():
      if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:
          nn.init.zeros_(m.bias)

  def forward(self, image, audio):
    """Compute the correspondence score and localisation map for a batch.

    Returns:
      corresponds: max of the localisation map per sample, with all singleton
        dimensions squeezed (shape (B,), or a 0-d tensor when B == 1).
      localisation: sigmoid map of shape (B, h, w), where (h, w) is the spatial
        size of the image feature map.
    """
    # process image: (B, C, h, w) -> (B, h*w, C), one embedding per location
    img = self.conv6(self.conv5(self.visionNet(image)))
    batch, channels, height, width = img.shape
    img = img.reshape(batch, channels, height * width).permute(0, 2, 1)

    # process audio: global max-pool to one vector per clip, then two FC layers
    aud = self.pool4(self.audioNet(audio)).flatten(1)
    aud = self.fc2(F.relu(self.fc1(aud)))
    aud = aud.unsqueeze(2)  # (B, C, 1)

    # fuse embeddings: dot product of the audio vector with every image location.
    # The map is reshaped to the actual feature-map size instead of a fixed 8x8.
    dot_product = (img @ aud).reshape(batch, 1, height, width)

    localisation = self.sigmoid(self.conv7(dot_product))
    corresponds = self.maxpool(localisation).squeeze()

    return corresponds, localisation.squeeze(1)

In [None]:
#reduced layers network
class VisionConvNet(nn.Module):
  """Reduced-depth convolutional trunk of the image branch.

  Five `3x3 conv -> BatchNorm -> ReLU` stages with 64, 128, 256, 512 and 512
  output channels; a 2x2 max-pool follows each of the first three stages. The
  first convolution has stride 2.

  NOTE: running this cell rebinds the name `VisionConvNet`, replacing the
  definitions from earlier in the notebook.
  """

  def __init__(self):
    super(VisionConvNet, self).__init__()

    def stage(c_in, c_out, stride=1):
      # 3x3 convolution (padding 1) followed by batch norm and ReLU
      return [nn.Conv2d(c_in, c_out, 3, stride, 1), nn.BatchNorm2d(c_out), nn.ReLU()]

    # The layer order is unchanged, so the state_dict keys (net.0 ... net.17)
    # stay compatible with previously saved checkpoints.
    self.net = nn.Sequential(
        *stage(3, 64, stride=2),   # conv1_1
        nn.MaxPool2d(2),           # pool1
        *stage(64, 128),           # conv2_1
        nn.MaxPool2d(2),           # pool2
        *stage(128, 256),          # conv3_1
        nn.MaxPool2d(2),           # pool3
        *stage(256, 512),          # conv4_1
        *stage(512, 512),          # conv4_2
    )

  def norm_init(self, mean, std):
    """Re-initialise every Conv2d: weights ~ N(mean, std), biases set to zero.

    Walks `self.modules()` (recursive) so the convolutions nested inside
    `self.net` are reached. Iterating `self._modules` would only yield the
    child *names* (strings) and therefore never match a layer type.
    """
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:
          nn.init.zeros_(m.bias)

  def forward(self, x):
    """Return the feature map produced by the convolutional stack for input `x`."""
    return self.net(x)

class AudioConvNet(nn.Module):
  """Reduced-depth convolutional trunk of the audio branch.

  Single input channel; five `3x3 conv -> BatchNorm -> ReLU` stages with 64,
  64, 128, 256 and 512 output channels, and a 2x2 max-pool after the 2nd, 3rd
  and 4th stage. The first convolution has stride 2.

  NOTE: running this cell rebinds the name `AudioConvNet`, replacing the
  definitions from earlier in the notebook.
  """

  def __init__(self):
    super(AudioConvNet, self).__init__()

    def stage(c_in, c_out, stride=1):
      # 3x3 convolution (padding 1) followed by batch norm and ReLU
      return [nn.Conv2d(c_in, c_out, 3, stride, 1), nn.BatchNorm2d(c_out), nn.ReLU()]

    # The layer order is unchanged, so the state_dict keys (net.0 ... net.17)
    # stay compatible with previously saved checkpoints.
    self.net = nn.Sequential(
        *stage(1, 64, stride=2),   # conv1_1
        *stage(64, 64),            # conv1_2
        nn.MaxPool2d(2),           # pool1
        *stage(64, 128),           # conv2_1 (no conv2_2 in this variant)
        nn.MaxPool2d(2),           # pool2
        *stage(128, 256),          # conv3_1
        nn.MaxPool2d(2),           # pool3
        *stage(256, 512),          # conv4_1 (no conv4_2 in this variant)
    )

  def norm_init(self, mean, std):
    """Re-initialise every Conv2d: weights ~ N(mean, std), biases set to zero.

    Walks `self.modules()` (recursive) so the convolutions nested inside
    `self.net` are reached. Iterating `self._modules` would only yield the
    child *names* (strings) and therefore never match a layer type.
    """
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:
          nn.init.zeros_(m.bias)

  def forward(self, x):
    """Return the feature map produced by the convolutional stack for input `x`."""
    return self.net(x)

class AVOLNet(nn.Module):
  """Audio-visual network built on the reduced-depth trunks (128-d embeddings).

  The image branch yields one 128-d embedding per spatial location, the audio
  branch yields a single 128-d embedding per clip. Their dot product at every
  location is passed through a 1x1 conv and a sigmoid to give the localisation
  map; the maximum of that map is the correspondence score.

  NOTE: running this cell rebinds the name `AVOLNet`, replacing the
  definitions from earlier in the notebook.
  """

  def __init__(self):
    super(AVOLNet, self).__init__()

    # image network
    self.visionNet = VisionConvNet()
    self.conv5 = nn.Conv2d(512, 128, 1)
    self.conv6 = nn.Conv2d(128, 128, 1)

    # audio network
    self.audioNet = AudioConvNet()
    self.pool4 = nn.AdaptiveMaxPool2d(1)
    self.fc1 = nn.Linear(512,128)
    self.fc2 = nn.Linear(128,128)

    # fusion network
    self.conv7 = nn.Conv2d(1,1,1)
    self.sigmoid = nn.Sigmoid()
    self.maxpool = nn.AdaptiveMaxPool2d(1)

  def norm_init(self, mean, std):
    """Re-initialise every Conv2d/Linear layer: weights ~ N(mean, std), biases zero.

    `self.modules()` is recursive, so this covers the layers owned directly by
    this module as well as every convolution inside both sub-networks.
    Iterating `self._modules` would only yield child *names* (strings) and
    never match a layer type, leaving all weights untouched.
    """
    for m in self.modules():
      if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:
          nn.init.zeros_(m.bias)

  def forward(self, image, audio):
    """Compute the correspondence score and localisation map for a batch.

    Returns:
      corresponds: max of the localisation map per sample, with all singleton
        dimensions squeezed (shape (B,), or a 0-d tensor when B == 1).
      localisation: sigmoid map of shape (B, h, w), where (h, w) is the spatial
        size of the image feature map.
    """
    # process image: (B, C, h, w) -> (B, h*w, C), one embedding per location
    img = self.conv6(self.conv5(self.visionNet(image)))
    batch, channels, height, width = img.shape
    img = img.reshape(batch, channels, height * width).permute(0, 2, 1)

    # process audio: global max-pool to one vector per clip, then two FC layers
    aud = self.pool4(self.audioNet(audio)).flatten(1)
    aud = self.fc2(F.relu(self.fc1(aud)))
    aud = aud.unsqueeze(2)  # (B, C, 1)

    # fuse embeddings: dot product of the audio vector with every image location.
    # The map is reshaped to the actual feature-map size instead of a fixed 8x8.
    dot_product = (img @ aud).reshape(batch, 1, height, width)

    localisation = self.sigmoid(self.conv7(dot_product))
    corresponds = self.maxpool(localisation).squeeze()

    return corresponds, localisation.squeeze(1)