In [2]:
import sys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


import matplotlib.pyplot as plt
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm_notebook as tqdm
from torch.utils.data import Dataset, DataLoader
import h5py  
import numpy as np
import os 
from scipy.misc import imresize
import cv2
import random
import soundfile as sf

In [3]:
class AudioDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, train, frames_len=40, transform=None, h5_file='data/data.h5', transform_label=None):
        """
        Args:
            train (bool): Whether or not to use training data
            frames (int): Number of video frames per video sample
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.train = train
        self.transform = transform
        self.frames_len = frames_len
        
        dataset = h5py.File(h5_file)
        self.videos_train = np.array(dataset['videos_train'])
        self.sounds_train = np.array(dataset['sounds_train'])
        self.videos_test = np.array(dataset['videos_test'])
        self.sounds_test = np.array(dataset['sounds_test'])
        dataset.close()
        
    def __len__(self):
        if self.train:
            return len(self.videos_train)
        return len(self.videos_test)

    def __getitem__(self, idx):
        if self.train:
            image = self.videos_train[idx]
            audio = self.sounds_train[idx]
        else:
            image = self.videos_test[idx]
            audio = self.sounds_test[idx]

        # Randomly sample 4 seconds from 10 second clip
        start = random.randint(0, 100-self.frames_len) # Start frame
        new_image = np.zeros((self.frames_len,256,256,1), dtype=np.uint8)
        for i in range(self.frames_len):
            new_image[i] = np.expand_dims(image[start+i],2)
        
        # Randomly align or misalign audio sample
        if random.random() < 0.5: # align
            audio = audio[int(start*220500/100.0):int(start*220500/100.0)+88200]
            label = 0
        else: # misalign
            shift = random.randint(20, 60) # frame shift amount
            if random.random() < 0.5: # Add shift
                start = np.clip(start-shift, 0, 100-self.frames_len)
            else: # Subtract shift
                start = np.clip(start+shift, 0, 100-self.frames_len)
            audio = audio[int(start*220500/100.0):int(start*220500/100.0)+88200]
            label = 1
            
        transform_image = np.zeros((self.frames_len,1,224,224), dtype=np.uint8)
        if self.transform:
            for i in range(self.frames_len):
                transform_image[i] = self.transform(new_image[i]) # Transform image frames
            
        return (transform_image, audio, label)

In [4]:
import sys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

class Block2(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, kernel_size, stride, downsample=None):
        super(Block2, self).__init__()
        self.out_channels = out_channels
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=1, groups=1, bias=True)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, stride=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class Block3(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, kernel_size=(1,1,1), stride=1, downsample=None, padding=0):
        super(Block3, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding=padding, dilation=1, groups=1, bias=True)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=(1,1,1), stride=1)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)

        return out

def Linear(in_features, out_features, dropout=0.):
    m = nn.Linear(in_features, out_features)
    m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
    m.bias.data.zero_()
    return nn.utils.weight_norm(m)

class alignment(nn.Module):
    def __init__(self):
        super(alignment, self).__init__()
        """Sound Features"""
        self.conv1_1 = nn.Conv1d(2, 64, 65, stride=4, padding=0, dilation=1, groups=1, bias=True)
        self.pool1_1 = nn.MaxPool1d(4, stride=4)

        self.s_net_1 = self._make_layer(Block2, 64, 128, 15, 4, 1)
        self.s_net_2 = self._make_layer(Block2, 128, 128, 15, 4, 1)
        self.s_net_3 = self._make_layer(Block2, 128, 256, 15, 4, 1)
        
        self.pool1_2 = nn.MaxPool1d(3, stride=3)
        self.conv1_2 = nn.Conv1d(256, 128, 3, stride=1, padding=0, dilation=1, groups=1, bias=True)
        
        """Image Features"""
        self.conv3_1 = nn.Conv3d(1, 64, (5,7,7), (2,2,2), padding=(2,3,3), dilation=1, groups=1, bias=True)
        self.pool3_1 = nn.MaxPool3d((1,3,3), (1,2,2), padding=(0,1,1))
        self.im_net_1 = self._make_layer(Block3, 64, 64, (3,3,3), (2,2,2), 2)

        """Fuse Features"""
        self.fractional_maxpool = nn.FractionalMaxPool2d((3,1), output_size=(10, 1))
        self.conv3_2 = nn.Conv3d(192, 512, (1, 1, 1))
        self.conv3_3 = nn.Conv3d(512, 128, (1, 1, 1))
        self.joint_net_1 = self._make_layer(Block3, 128, 128, (3,3,3), (2,2,2), 2)
        self.joint_net_2 = self._make_layer(Block3, 128, 256, (3,3,3), (1,2,2), 2)
        self.joint_net_3 = self._make_layer(Block3, 256, 512, (3,3,3), (1,2,2), 2)

        #TODO: Global avg pooling, fc and sigmoid
        self.fc = Linear(512,2)

    def _make_layer(self, block, in_channels, out_channels, kernel_size, stride, blocks):
        downsample = None
        if stride != 1 or in_channels != out_channels * block.expansion:
            if isinstance(kernel_size, int):
                downsample = nn.Sequential(
                    nn.Conv1d(in_channels, out_channels * block.expansion, kernel_size, stride),
                    nn.BatchNorm1d(out_channels * block.expansion),
                )
                layers = []
                layers.append(block(in_channels, out_channels, kernel_size, stride, downsample))
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(in_channels, out_channels * block.expansion, kernel_size, stride, padding=1),
                    nn.BatchNorm3d(out_channels * block.expansion),
                )
                layers = []
                layers.append(block(in_channels, out_channels, kernel_size, stride, downsample, padding=1))

        
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, batchsize, sounds, images):
        sounds = sounds.view(batchsize, 2, -1)
        _, num, _, xd, yd, = images.shape
        images = images.view(batchsize, 1, num, xd, yd)
        
        out_s = self.conv1_1(sounds)
        out_s = self.pool1_1(out_s)

        out_s = self.s_net_1(out_s)
        out_s = self.s_net_2(out_s)
        out_s = self.s_net_3(out_s)

        out_s = self.pool1_2(out_s)
        out_s = self.conv1_2(out_s)
        
        out_im = self.conv3_1(images)
        out_im = self.pool3_1(out_im)
        out_im = self.im_net_1(out_im)

        #tile audio, concatenate channel wise
        out_s = self.fractional_maxpool(out_s.unsqueeze(3)) # Reduce dimension from 25 to 8
        out_s = out_s.squeeze(3).view(-1, 1, 1).repeat(1, 28, 28).view(-1,128,10,28,28) # Tile
        out_joint = torch.cat((out_s, out_im),1)
        out_joint = self.conv3_2(out_joint)
        out_joint = self.conv3_3(out_joint)
        out_joint = self.joint_net_1(out_joint)
        out_joint = self.joint_net_2(out_joint)
        out_joint = self.joint_net_3(out_joint)
        """Global Average Pooling"""
        out_joint = F.avg_pool3d(out_joint, kernel_size=out_joint.size()[2:]).view(batchsize,-1)
#         out_joint = out_joint.view(batchsize, 512, -1).mean(2)
        out_joint = self.fc(out_joint)
        out_joint = F.sigmoid(out_joint)
        return out_joint

In [17]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

transform = transforms.Compose([
transforms.ToPILImage(),
# transforms.RandomHorizontalFlip(),
transforms.RandomCrop(224),
transforms.ToTensor()])

train_dataset = AudioDataset(train=True,transform=transform)
test_dataset = AudioDataset(train=False,transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=8, shuffle=False, num_workers=4)

model_align = alignment().cuda()

In [18]:
loss_fn = nn.CrossEntropyLoss()
optimizer_align = optim.Adam(model_align.parameters(), lr = 1e-5)
for epoch in range(500):
    accs = []
    losses = []
    model_align.train()
    for batch_idx, (images, sounds, labels) in enumerate(train_loader):
        images_v = Variable(images.type(torch.FloatTensor)).cuda()
        sounds_v = Variable(sounds.type(torch.FloatTensor)).cuda()
        labels_v = Variable(labels).cuda()
        
        optimizer_align.zero_grad()
        aligned_res = model_align(images.shape[0], sounds_v, images_v)
        loss = loss_fn(aligned_res, labels_v)
        loss.backward()
        optimizer_align.step()
        losses.append(loss.item())
        accs.append(np.mean((torch.argmax(aligned_res,1) == labels_v).detach().cpu().numpy()))
    print("Epoch :", epoch, np.mean(losses), np.mean(accs))
    if (epoch + 1)%25 == 0:
        accs = []
        losses = []
        model_align.eval()
        for batch_idx, (images, sounds, labels) in enumerate(test_loader):
            images_v = Variable(images.type(torch.FloatTensor)).cuda()
            sounds_v = Variable(sounds.type(torch.FloatTensor)).cuda()
            labels_v = Variable(labels).cuda()
            aligned_res = model_align(images.shape[0], sounds_v, images_v)
            loss = loss_fn(aligned_res, labels_v)
            losses.append(loss.item())
            accs.append(np.mean((torch.argmax(aligned_res,1) == labels_v).detach().cpu().numpy()))
        print("Validation :", epoch, np.mean(losses), np.mean(accs))

('Epoch :', 0, 0.70576853481764645, 0.46546391752577326)
('Epoch :', 1, 0.69721758549975366, 0.51649484536082468)
('Epoch :', 2, 0.69908098154461262, 0.4938144329896908)
('Epoch :', 3, 0.69515974066921116, 0.52164948453608251)
('Epoch :', 4, 0.6987021577726934, 0.48556701030927846)
('Epoch :', 5, 0.69420738871564569, 0.50051546391752577)
('Epoch :', 6, 0.69650955912993129, 0.49484536082474229)
('Epoch :', 7, 0.69271097109489832, 0.51855670103092777)
('Epoch :', 8, 0.69595667013188001, 0.46340206185567007)
('Epoch :', 9, 0.69009090022942454, 0.55360824742268033)
('Epoch :', 10, 0.6957782334888104, 0.53711340206185565)
('Epoch :', 11, 0.69741888820510545, 0.47989690721649481)
('Epoch :', 12, 0.69452309669907564, 0.49020618556701029)
('Epoch :', 13, 0.69561882854736956, 0.48041237113402063)
('Epoch :', 14, 0.69550693403814257, 0.49742268041237114)
('Epoch :', 15, 0.69327351788884584, 0.51237113402061862)
('Epoch :', 16, 0.6940771998818388, 0.52061855670103097)
('Epoch :', 17, 0.6943427735

In [20]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=8, shuffle=False, num_workers=4)

loss_fn = nn.CrossEntropyLoss()
optimizer_align = optim.Adam(model_align.parameters(), lr = 5e-6)
for epoch in range(500):
    accs = []
    losses = []
    model_align.train()
    for batch_idx, (images, sounds, labels) in enumerate(train_loader):
        images_v = Variable(images.type(torch.FloatTensor)).cuda()
        sounds_v = Variable(sounds.type(torch.FloatTensor)).cuda()
        labels_v = Variable(labels).cuda()
        
        optimizer_align.zero_grad()
        aligned_res = model_align(images.shape[0], sounds_v, images_v)
        loss = loss_fn(aligned_res, labels_v)
        loss.backward()
        optimizer_align.step()
        losses.append(loss.item())
        accs.append(np.mean((torch.argmax(aligned_res,1) == labels_v).detach().cpu().numpy()))
    print("Epoch :", epoch, np.mean(losses), np.mean(accs))
    if (epoch + 1)%25 == 0:
        accs = []
        losses = []
        model_align.eval()
        for batch_idx, (images, sounds, labels) in enumerate(test_loader):
            images_v = Variable(images.type(torch.FloatTensor)).cuda()
            sounds_v = Variable(sounds.type(torch.FloatTensor)).cuda()
            labels_v = Variable(labels).cuda()
            aligned_res = model_align(images.shape[0], sounds_v, images_v)
            loss = loss_fn(aligned_res, labels_v)
            losses.append(loss.item())
            accs.append(np.mean((torch.argmax(aligned_res,1) == labels_v).detach().cpu().numpy()))
        print("Validation :", epoch, np.mean(losses), np.mean(accs))

('Epoch :', 0, 0.61275880698297847, 0.66803278688524592)
('Epoch :', 1, 0.57513262600195214, 0.71311475409836067)
('Epoch :', 2, 0.602655645765242, 0.68647540983606559)
('Epoch :', 3, 0.56853482518039766, 0.71721311475409832)
('Epoch :', 4, 0.58947939755486656, 0.68237704918032782)
('Epoch :', 5, 0.58929789896871221, 0.70696721311475408)
('Epoch :', 6, 0.59230625482856247, 0.70081967213114749)
('Epoch :', 7, 0.58039294305394906, 0.68442622950819676)
('Epoch :', 8, 0.57267311655107089, 0.71721311475409832)
('Epoch :', 9, 0.55981989180455438, 0.72950819672131151)
('Epoch :', 10, 0.56855407406072145, 0.72336065573770492)
('Epoch :', 11, 0.57802983043623757, 0.7151639344262295)
('Epoch :', 12, 0.57632743872579983, 0.71311475409836067)
('Epoch :', 13, 0.56307173117262421, 0.71926229508196726)
('Epoch :', 14, 0.56574601374688693, 0.72131147540983609)
('Epoch :', 15, 0.56221031996070359, 0.72540983606557374)
('Epoch :', 16, 0.58829783807035352, 0.71311475409836067)
('Epoch :', 17, 0.591810861

('Epoch :', 137, 0.54554667824604475, 0.73975409836065575)
('Epoch :', 138, 0.53721777923771596, 0.77254098360655743)
('Epoch :', 139, 0.55662029932756896, 0.75409836065573765)
('Epoch :', 140, 0.52216583588084231, 0.78073770491803274)
('Epoch :', 141, 0.52081206001219205, 0.7848360655737705)
('Epoch :', 142, 0.51969486181853253, 0.77663934426229508)
('Epoch :', 143, 0.53233181111148142, 0.76229508196721307)
('Epoch :', 144, 0.53672955417242207, 0.76024590163934425)
('Epoch :', 145, 0.51390019059181213, 0.77868852459016391)
('Epoch :', 146, 0.53486165795169893, 0.76024590163934425)
('Epoch :', 147, 0.54369299431316187, 0.74590163934426235)
('Epoch :', 148, 0.55131172352149838, 0.72745901639344257)
('Epoch :', 149, 0.53374801698278207, 0.76639344262295084)
('Validation :', 149, 0.76718379901005673, 0.50206043956043955)
('Epoch :', 150, 0.54564322213657568, 0.75204918032786883)
('Epoch :', 151, 0.55401085585844323, 0.74590163934426235)
('Epoch :', 152, 0.53300759733700365, 0.772540983606

('Epoch :', 272, 0.56120503437323643, 0.72540983606557374)
('Epoch :', 273, 0.54945311683123232, 0.74795081967213117)
('Epoch :', 274, 0.56990640524958003, 0.73155737704918034)
('Validation :', 274, 0.77044473359218013, 0.50686813186813184)
('Epoch :', 275, 0.53401999346545481, 0.78073770491803274)
('Epoch :', 276, 0.52673317907286477, 0.77868852459016391)
('Epoch :', 277, 0.52010276278511425, 0.79303278688524592)
('Epoch :', 278, 0.50887426733970642, 0.80327868852459017)
('Epoch :', 279, 0.52095743916073789, 0.7848360655737705)
('Epoch :', 280, 0.51466566177665207, 0.79713114754098358)
('Epoch :', 281, 0.51463027029741004, 0.79713114754098358)
('Epoch :', 282, 0.52052267066767954, 0.79508196721311475)
('Epoch :', 283, 0.52313673593958865, 0.78073770491803274)
('Epoch :', 284, 0.51525451316208137, 0.78688524590163933)
('Epoch :', 285, 0.57648954967983435, 0.72336065573770492)
('Epoch :', 286, 0.51853262791868115, 0.77663934426229508)
('Epoch :', 287, 0.53605243954502169, 0.768442622950

('Epoch :', 405, 0.53049501923264053, 0.78073770491803274)
('Epoch :', 406, 0.49782596797239587, 0.80942622950819676)
('Epoch :', 407, 0.5261619799449796, 0.77049180327868849)
('Epoch :', 408, 0.49121683200851818, 0.80122950819672134)
('Epoch :', 409, 0.51686629701833253, 0.79303278688524592)
('Epoch :', 410, 0.52906868301454135, 0.78073770491803274)
('Epoch :', 411, 0.53781058456076947, 0.77254098360655743)
('Epoch :', 412, 0.54585489679555421, 0.76024590163934425)
('Epoch :', 413, 0.54684495583909454, 0.75819672131147542)
('Epoch :', 414, 0.53546686006374045, 0.77254098360655743)
('Epoch :', 415, 0.55377671953107488, 0.74180327868852458)
('Epoch :', 416, 0.53003335780784733, 0.77254098360655743)
('Epoch :', 417, 0.50529380118260614, 0.80532786885245899)
('Epoch :', 418, 0.53081881999969482, 0.77254098360655743)
('Epoch :', 419, 0.52584376100633967, 0.78073770491803274)
('Epoch :', 420, 0.54646620594087192, 0.75204918032786883)
('Epoch :', 421, 0.50143950151615457, 0.80737704918032782

In [21]:
torch.save(aligned_res, 'test_500.pth')

RuntimeError: cuda runtime error (4) : unspecified launch failure at /pytorch/torch/csrc/generic/serialization.cpp:17