In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision 
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable 
import numpy as np 
import tqdm 
from torch.utils import data 
import librosa
from scipy.io import wavfile
import json
import random
from scipy.signal import spectrogram

In [2]:
class Args(object):
    def __init__(self, name="Net", batch_size=4, test_batch_size=50,
                epochs=30, lr=1e-4, optimizer='Adam', momentum=0.9, weight_decay=0.0,
                seed=0, cuda = True):
        self.name = name;
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.weight_decay = weight_decay
        self.optimizer = optimizer
        self.momentum = momentum
        self.seed = seed
        self.cuda = cuda and torch.cuda.is_available()

In [11]:
# constants used for normalizing (in pixel coordinate)
com_traj_minx = -1476
com_traj_maxx = 1758
com_traj_miny = -703
com_traj_maxy = 1785

In [12]:

class Dataset(data.Dataset):
    
    def __init__(self, list_IDs, transform=None):
        #'Initialization'
        self.list_IDs = list_IDs
        self.transform = transform
        
    def __len__(self):
        return len(self.list_IDs)
    
    def __getitem__(self, index):
        # Generates one sample of data 
        ID = self.list_IDs[index]
        
        # Audio Feature Extraction
        filename = 'dataset_archive/audio/audio' + str(ID) + '.wav'
        srate, audio = wavfile.read(filename)
        audio = np.array(audio)
        
        for i in range(0,7):
            freq,t,phase = spectrogram(audio[:,i], nperseg=512 , noverlap=256, mode='phase')
            freq,t,mag = spectrogram(audio[:,i], nperseg=512, noverlap=256, mode='magnitude')
            if (i == 0):
                audio_phase = torch.from_numpy(phase).unsqueeze(2).float()
                audio_mag = torch.from_numpy(mag).unsqueeze(2).float()
            else:
                audio_phase = torch.cat((audio_phase, torch.from_numpy(phase).unsqueeze(2).float()),2)
                audio_mag = torch.cat((audio_mag, torch.from_numpy(mag).unsqueeze(2).float()),2)
        audio_img = torch.cat((audio_mag, audio_phase), 2)
        audio_img = audio_img.permute(1,0,2)
        
        complete_traj = np.load("dataset_archive/trajectory/traj"+str(ID)+".npy")
        complete_traj = (complete_traj[0] - complete_traj).astype(float)
        complete_traj[:,0] = (complete_traj[:,0] - com_traj_minx) / (com_traj_maxx - com_traj_minx)
        complete_traj[:,1] = (complete_traj[:,1] - com_traj_miny) / (com_traj_maxy - com_traj_miny)
    
        
        if complete_traj.shape[0] > 135:
            complete_traj = complete_traj[0:135,:]
        if complete_traj.shape[0] < 135:
            for i in range(0, 135 - complete_traj.shape[0]):
                complete_traj = np.append(complete_traj, complete_traj[-1].reshape(1,2), axis=0)    

        complete_traj = torch.from_numpy(complete_traj).type(torch.FloatTensor)

        return audio_img, complete_traj

In [13]:
# Read in index for train, validation, test
with open('partition.json') as f:
    partition = json.load(f)
args = Args()
params = {'batch_size': args.batch_size,
          'shuffle': True,
          'num_workers': 4}
# Generators
training_set = Dataset(partition['train'])
training_generator = data.DataLoader(training_set, **params)

validation_set = Dataset(partition['val'])
validation_generator = data.DataLoader(validation_set, **params)

In [14]:
class Encoder(nn.Module):
    """Encoder is part of both TrajectoryGenerator and
    TrajectoryDiscriminator"""
    def __init__(
        self, embedding_dim=64, h_dim=64, mlp_dim=1024, num_layers=1,
        dropout=0.0
    ):
        super(Encoder, self).__init__()

        self.mlp_dim = 1024
        self.h_dim = h_dim
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        self.encoder = nn.GRU(
            embedding_dim, h_dim, num_layers, dropout=dropout, bidirectional=True
        )
        self.spatial_embedding = nn.Linear(64, embedding_dim)

    def init_hidden(self, batch):
        return torch.zeros(self.num_layers*2, batch, self.h_dim).cuda()

    def forward(self, obs_traj):
        """
        Inputs:
        - obs_traj: Tensor of shape (obs_len, batch, 2)
        """
        # Encode observed Trajectory
        batch = obs_traj.size(1)
        state_tuple = self.init_hidden(batch)
        output, state = self.encoder(obs_traj, state_tuple)
        final_h = state[0]
        return output, final_h
    

In [15]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.cn1 = nn.Conv3d(in_channels=1, out_channels=64, kernel_size=(5,5,14), padding=(1,1,3))
        self.batchNorm1 = nn.BatchNorm3d(64)
        self.pooling1 = nn.MaxPool3d(kernel_size = (1,8,1), stride=(1,4,1))
        self.cn2 = nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(5,5,7), padding=(1,1,2))
        self.batchNorm2 = nn.BatchNorm3d(64)
        self.pooling2 = nn.MaxPool3d(kernel_size = (1,8,1), stride=(1,4,1))
        self.cn3 = nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3,3,5), padding=(1,1,1))
        self.batchNorm3 = nn.BatchNorm3d(64)
        self.pooling3 = nn.MaxPool3d(kernel_size = (1,4,1), stride=(1,2,1))
        self.cn4 = nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3,3,3), padding=(1,0,0))
        self.batchNorm4 = nn.BatchNorm3d(64)
        self.pooling4 = nn.MaxPool3d(kernel_size = (1,4,1), stride=(1,2,1))
        
        self.encoder = Encoder(
            embedding_dim=64,
            h_dim=64,
            num_layers=1,
            dropout=0
        )
        
        self.encoder2 = Encoder(
            embedding_dim=64,
            h_dim=64,
            num_layers=1,
            dropout=0
        )
        
        self.ln = nn.Linear(128, 64)
        self.spatial_trimming1 = nn.Linear(555, 256)
        self.spatial_trimming2 = nn.Linear(256, 135)
        
        self.traj_predictor1 = nn.Linear(128, 64)
        self.dropout = nn.Dropout(0.2)
        self.traj_predictor2 = nn.Linear(64, 16)
        self.dropout1 = nn.Dropout(0.1)
        self.traj_predictor3 = nn.Linear(16, 2)
        
    def forward(self, x):
        x = x.unsqueeze(1)
        x = F.relu(self.cn1(x))
        x = self.batchNorm1(x)
        x = self.pooling1(x)
        x = F.relu(self.cn2(x))
        x = self.batchNorm2(x)
        x = self.pooling2(x)
        x = F.relu(self.cn3(x))
        x = self.batchNorm3(x)
        x = self.pooling3(x)
        x = F.relu(self.cn4(x))
        x = self.batchNorm4(x)
        x = self.pooling4(x)
        x = x.squeeze(4)
        x = x.squeeze(3)
        x = x.permute(2,0,1)
        x, _ = self.encoder(x)
        x = self.ln(x)
        x = torch.tanh(x)
        x, _ = self.encoder2(x)
        x = torch.tanh(x)
        x = x.permute(1,2,0)
        x = F.relu(self.spatial_trimming1(x))
        x = self.spatial_trimming2(x)
        x = x.permute(0,2,1)
        x = F.relu(self.traj_predictor1(x))
        x = self.dropout(x)
        x = F.relu(self.traj_predictor2(x))
        x = self.dropout1(x)
        x = torch.sigmoid(self.traj_predictor3(x))
        return x[:,:,0], x[:,:,1]

In [None]:
model = Net()
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
if args.cuda:
    model = model.cuda()
train_losses, train_accs = [], []
val_losses, val_accs = [], []
min_test_loss = 1000.
for epoch in range(args.epochs):
    # Training
    print("epoch " + str(epoch))
    model.train()
    
    total_loss, total_acc = 0., 0.
    progress_bar = tqdm.tqdm(training_generator, desc='Training')
    
    for batch_idx, (data, target) in enumerate(progress_bar):
        total_loss, total_acc = 0., 0.
        data = data.float()
        target = target.float()
        
        if args.cuda:
            data,target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss_x = F.mse_loss(output[0], target[:,:,0], reduction="sum")
        loss_y = F.mse_loss(output[1], target[:,:,1], reduction="sum")
        loss = loss_x + loss_y
        loss.backward()
        optimizer.step()
        train_losses.append(loss.data.item())
        
        total_loss += loss.data
        
        progress_bar.clear()
        progress_bar.set_description(
            'Epoch: {} loss: {:.4f}'.format(
                epoch, total_loss / (batch_idx + 1)))
        progress_bar.refresh()
        
    # Validation
    model.eval()
    test_loss, acc = 0., 0.
    progress_bar = tqdm.tqdm(validation_generator, desc='Validation')
    with torch.no_grad():
        for data, target in progress_bar:
            data = data.float()
            target = target.float()
            if args.cuda:
                data,target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            data = data.float()
            output = model(data)
            loss_x = F.mse_loss(output[0], target[:,:,0], reduction="sum") 
            loss_y = F.mse_loss(output[1], target[:,:,1], reduction="sum")  # sum up batch loss
            test_loss += (loss_x + loss_y)
    test_loss /= len(partition['val'])
    val_losses.append(test_loss.item())
    if epoch % 10 == 0 and epoch > 0:
        args.lr /= 5
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr
    progress_bar.clear()
    progress_bar.write(
        '\nEpoch: {} validation test results - Average val_loss: {:.4f}'.format(
            epoch, test_loss))
    if (test_loss < min_test_loss):
        torch.save(model, "audio_checkpoint.pkl")
        min_test_loss = test_loss

Training:   0%|          | 0/288 [00:00<?, ?it/s]

epoch 0


Epoch: 0 loss: 0.1090: 100%|██████████| 288/288 [05:34<00:00,  1.16s/it]
Validation: 100%|██████████| 29/29 [00:08<00:00,  3.37it/s]
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
Training:   0%|          | 0/288 [00:00<?, ?it/s]


Epoch: 0 validation test results - Average val_loss: 9.0470
epoch 1


Epoch: 1 loss: 0.1609:  50%|█████     | 145/288 [02:50<02:50,  1.19s/it]