In [1]:
import sys
# Clear all imports from previous runs
sys.modules.pop('dataset', None)
sys.modules.pop('model', None)

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
torch.set_default_tensor_type('torch.DoubleTensor')
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from dataset import get_all_wind_directions
from dataset import TrajectoryDataset
from model import Phi_Net, H_Net_CrossEntropy, Model, save_model, load_model


  _C._set_default_tensor_type(t)


In [2]:
data_dir = Path('/Users/yefan/Desktop/CDS245/simple-quad-sim/data')
amplitudes = [0, 0.5, 1, 1.5, 2]
datasets = [TrajectoryDataset(data_dir, c) for c in amplitudes]


  self.data = pd.concat([self.data, pd.DataFrame({
100%|██████████| 125/125 [00:00<00:00, 297.17it/s]


X.shape: (500000, 11) Y.shape: (500000, 3) c.shape: (500000,)


  self.data = pd.concat([self.data, pd.DataFrame({
100%|██████████| 125/125 [00:00<00:00, 2065.44it/s]


X.shape: (500000, 11) Y.shape: (500000, 3) c.shape: (500000,)


  self.data = pd.concat([self.data, pd.DataFrame({
100%|██████████| 125/125 [00:00<00:00, 1678.84it/s]


X.shape: (500000, 11) Y.shape: (500000, 3) c.shape: (500000,)


  self.data = pd.concat([self.data, pd.DataFrame({
100%|██████████| 125/125 [00:00<00:00, 1645.56it/s]


X.shape: (500000, 11) Y.shape: (500000, 3) c.shape: (500000,)


  self.data = pd.concat([self.data, pd.DataFrame({
100%|██████████| 125/125 [00:00<00:00, 2096.43it/s]


X.shape: (500000, 11) Y.shape: (500000, 3) c.shape: (500000,)


In [3]:
omegas = [0, 0.5, 1, 1.5, 2]
wind_directions = get_all_wind_directions(data_dir)

In [12]:
options = {}
dim_x = 4 + 4 + 3
dim_y = 3
num_c = len(amplitudes)
dim_a = 3
learning_rate = 5e-4

print('dims of (x, y) are', (dim_x, dim_y))
print('there are ' + str(num_c) + ' different conditions')

dims of (x, y) are (11, 3)
there are 5 different conditions


In [13]:

Trainloader = []
Adaptloader = []
for i in range(len(amplitudes)):
    fullset = datasets[i]

    l = len(fullset)
    trainset, adaptset = random_split(fullset, [int(l*2/3), l-int(l*2/3)])

    trainloader = DataLoader(trainset, batch_size=256, shuffle=True)
    adaptloader = DataLoader(adaptset, batch_size=32, shuffle=True)

    Trainloader.append(trainloader)
    Adaptloader.append(adaptloader)


## Initialize the models

In [14]:
phi_net = Phi_Net(dim_x, dim_a)
h_net = H_Net_CrossEntropy(dim_a, num_c)
criterion = nn.MSELoss()
criterion_h = nn.CrossEntropyLoss()
optimizer_h = optim.Adam(h_net.parameters(), lr=learning_rate)
optimizer_phi = optim.Adam(phi_net.parameters(), lr=learning_rate)

In [17]:
model_save_freq = 50

Loss_f = [] # combined force prediction loss
Loss_c = [] # combined adversarial loss

alpha = 0.01
SN = 2

for epoch in range(1000):
    arr = np.arange(len(amplitudes))
    np.random.shuffle(arr)

    # Running loss over all subdatasets
    running_loss_f = 0.0
    running_loss_c = 0.0

    for i in arr:
        with torch.no_grad():
            adaptloader = Adaptloader[i]
            kshot_data = next(iter(adaptloader))
            trainloader = Trainloader[i]
            data = next(iter(trainloader))
        
        optimizer_phi.zero_grad()

        '''
        Least-square to get $a$ from K-shot data
        '''
        X = kshot_data[0]
        Y = kshot_data[1]
        phi = phi_net(X)
        phi_T = phi.transpose(0, 1)
        A = torch.inverse(torch.mm(phi_T, phi))
        a = torch.mm(torch.mm(A, phi_T), Y)

        if torch.norm(a, 'fro') > 10:
            a = a / torch.norm(a, 'fro') * 10
        
        '''
        Batch training \phi_net
        '''
        inputs = data[0]
        labels = data[1]

        c_labels = data[2].type(torch.long)

        # forward + backward + optimize
        outputs = torch.mm(phi_net(inputs), a)
        loss_f = criterion(outputs, labels)
        temp = phi_net(inputs)

        loss_c = criterion_h(h_net(temp), c_labels)

        loss_phi = loss_f - alpha * loss_c
        loss_phi.backward()
        optimizer_phi.step()

        '''
        Discriminator training
        '''
        if np.random.rand() <= 1.0 / 2:
            optimizer_h.zero_grad()
            temp = phi_net(inputs)
            
            loss_c = criterion_h(h_net(temp), c_labels)
            
            loss_h = loss_c
            loss_h.backward()
            optimizer_h.step()
        
        '''
        Spectral normalization
        '''
        if SN > 0:
            for param in phi_net.parameters():
                M = param.detach().numpy()
                if M.ndim > 1:
                    s = np.linalg.norm(M, 2)
                    if s > SN:
                        param.data = param / s * SN
        
        running_loss_f += loss_f.item()
        running_loss_c += loss_c.item()
    
    # Save statistics
    Loss_f.append(running_loss_f / len(amplitudes))
    Loss_c.append(running_loss_c / len(amplitudes))
    if epoch % 10 == 0:
        print('[%d] loss_f: %.2f loss_c: %.2f' % (epoch + 1, running_loss_f / len(amplitudes), running_loss_c / len(amplitudes)))

    if epoch % model_save_freq == 0:
        save_model(phi_net=phi_net, h_net=h_net, modelname=f'epoch{epoch}')

[1] loss_f: 5.39 loss_c: 1.06
[11] loss_f: 2.21 loss_c: 1.06
[21] loss_f: 5.06 loss_c: 1.06
[31] loss_f: 4.91 loss_c: 1.06
[41] loss_f: 6.43 loss_c: 1.06
[51] loss_f: 4.22 loss_c: 1.06
[61] loss_f: 6.60 loss_c: 1.06
[71] loss_f: 2.31 loss_c: 1.06
[81] loss_f: 3.13 loss_c: 1.06
[91] loss_f: 3.33 loss_c: 1.06
[101] loss_f: 5.36 loss_c: 1.06
[111] loss_f: 5.03 loss_c: 1.06
[121] loss_f: 1.80 loss_c: 1.06
[131] loss_f: 2.16 loss_c: 1.06
[141] loss_f: 3.36 loss_c: 1.06
[151] loss_f: 2.57 loss_c: 1.06
[161] loss_f: 1.80 loss_c: 1.06
[171] loss_f: 2.82 loss_c: 1.06
[181] loss_f: 2.42 loss_c: 1.06
[191] loss_f: 3.66 loss_c: 1.06
[201] loss_f: 4.27 loss_c: 1.06
[211] loss_f: 3.50 loss_c: 1.06
[221] loss_f: 4.02 loss_c: 1.06
[231] loss_f: 2.25 loss_c: 1.06
[241] loss_f: 3.45 loss_c: 1.06
[251] loss_f: 2.43 loss_c: 1.06
[261] loss_f: 2.59 loss_c: 1.06
[271] loss_f: 2.84 loss_c: 1.06
[281] loss_f: 2.89 loss_c: 1.06
[291] loss_f: 2.58 loss_c: 1.06
[301] loss_f: 1.33 loss_c: 1.06
[311] loss_f: 1.92 