In [1]:
import torch
from kwisatzHaderach import KwisatzHaderach
from Kaisarion import Kaisarion
import json
import os
import tqdm
import numpy as np
from datagen import *
from trainer import Trainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def euclidean_distance(a, b):
    """Return the Euclidean (L2) distance between `a` and `b` along the last axis.

    The last dimension is reduced away, so the result has one fewer dimension
    than the (broadcast) inputs. A constant of 1e-12 is added under the square
    root, which keeps the argument strictly positive even when `a == b`
    (where the derivative of sqrt would otherwise be unbounded).
    """
    squared_diff = (a - b).pow(2)
    return (squared_diff.sum(dim=-1) + 1e-12).sqrt()

def mean_distance(a, b):
    """Return the scalar mean of the last-axis Euclidean distances between `a` and `b`."""
    return euclidean_distance(a, b).mean()
# Disabled earlier version of loss_fn, kept by wrapping it in a bare string
# literal rather than deleting it. It differs from the active loss_fn by using
# a fixed gamma of 0.5 and having no position term.
# Because this literal is the last expression in the cell, the notebook echoes
# its repr as the cell's output.
'''
def loss_fn(pr_acc, gt_acc, num_neighbors):
    gamma = 0.5
    neighbor_scale = 1/100
    importance = torch.exp(neighbor_scale * num_neighbors) # removed minus sign to give more importance to particles with more neighbors
    importance = importance / torch.max(importance)
    euclidean_distances = euclidean_distance(pr_acc, gt_acc)
    if importance.size()[0] == 0:
        importance = 1.0
    return torch.mean(importance *
                        euclidean_distances**gamma), torch.mean(euclidean_distances)
'''

'\ndef loss_fn(pr_acc, gt_acc, num_neighbors):\n    gamma = 0.5\n    neighbor_scale = 1/100\n    importance = torch.exp(neighbor_scale * num_neighbors) # removed minus sign to give more importance to particles with more neighbors\n    importance = importance / torch.max(importance)\n    euclidean_distances = euclidean_distance(pr_acc, gt_acc)\n    if importance.size()[0] == 0:\n        importance = 1.0\n    return torch.mean(importance *\n                        euclidean_distances**gamma), torch.mean(euclidean_distances)\n'

In [3]:
def loss_fn(pr_acc, gt_acc, num_neighbors, gt_pos, pr_pos, pos_importance=1.0, gamma=1.0):
    """Neighbor-weighted loss over predicted accelerations and positions.

    Each element's distance is weighted by ``exp(num_neighbors / 100)``,
    normalised by the largest weight, so elements with more neighbors
    contribute more. The acceleration term and the position term share the
    same weights.

    Args:
        pr_acc: predicted accelerations.
        gt_acc: ground-truth accelerations.
        num_neighbors: tensor of neighbor counts.
            # assumes one entry per row of the distance vectors so that it
            # broadcasts against them - TODO confirm against the model.
        gt_pos: ground-truth positions.
        pr_pos: predicted positions. The distance is symmetric, so passing
            ``gt_pos`` / ``pr_pos`` in swapped order gives the same result.
        pos_importance: multiplier applied to the position term.
        gamma: exponent applied to both distance vectors before weighting;
            1.0 leaves them unchanged.

    Returns:
        ``(loss, mean_acc_distance)`` where ``loss`` is the weighted sum of
        both terms and ``mean_acc_distance`` is the unweighted mean
        acceleration distance.
    """
    neighbor_scale = 1/100
    acc_distances = euclidean_distance(pr_acc, gt_acc)
    pos_distances = euclidean_distance(pr_pos, gt_pos)

    # The empty case has to be handled BEFORE normalising: torch.max raises on
    # a tensor with no elements, so a check placed after it can never run.
    if num_neighbors.numel() == 0:
        importance = 1.0
    else:
        # positive exponent: more neighbors -> more importance
        importance = torch.exp(neighbor_scale * num_neighbors)
        importance = importance / torch.max(importance)

    # Skip the pow entirely for gamma == 1.0 so the default path is exactly
    # the plain distance.
    if gamma != 1.0:
        acc_term = acc_distances**gamma
        pos_term = pos_distances**gamma
    else:
        acc_term = acc_distances
        pos_term = pos_distances

    loss = torch.mean(importance * acc_term) + pos_importance * torch.mean(importance * pos_term)
    return loss, torch.mean(acc_distances)

In [4]:
def get_new_pos_vel(acc, pos, vel, dt=0.01):
    """Advance position and velocity by one time step of length `dt`.

    The velocity is updated from the acceleration first, and the position is
    then advanced using that *updated* velocity.

    Returns:
        ``(new_pos, new_vel)``
    """
    velocity_change = acc * dt
    updated_vel = vel + velocity_change
    updated_pos = pos + updated_vel * dt
    return updated_pos, updated_vel

In [7]:


def train_epoch(model, data, batch_size, loss_fn, optimizer, device, use_custom_loss=False, loss_scale=64):
    """Run one training pass over `data` with a two-step rollout per sample.

    For every sample the model predicts an acceleration, the state is advanced
    with `get_new_pos_vel`, and the model is queried again on the *predicted*
    state. Each of the two steps contributes half of the sample's loss.
    Gradients are accumulated over a whole batch and applied with a single
    optimizer step.

    Args:
        model: callable as ``model(pos, vel, masses)``; must expose
            ``num_neighbors`` when `use_custom_loss` is true.
        data: sliceable sequence of dicts with keys 'masses', 'pos', 'vel',
            'acc', 'acc_next1' and, when `use_custom_loss` is true,
            'pos_next1' and 'pos_next2'. Only full batches are used; trailing
            samples that do not fill a batch are dropped.
        batch_size: number of samples per optimizer step.
        loss_fn: used only when `use_custom_loss` is true; called as
            ``loss_fn(pr_acc, gt_acc, num_neighbors, pr_pos, gt_pos)`` and
            expected to return ``(loss, mean_distance)``.
        optimizer: optimizer wrapping the model's parameters.
        device: device the batch tensors are moved to.
        use_custom_loss: if false, the loss is the mean Euclidean distance
            between predicted and ground-truth acceleration.
        loss_scale: constant factor applied to the averaged batch loss
            (defaults to the previously hard-coded 64).

    Prints the epoch's mean loss and mean acceleration distance; returns None.
    """
    model.train()

    # Every rollout step carries the same weight in the per-sample loss.
    step_weight = 0.5

    num_batches = len(data) // batch_size
    all_losses = []
    all_dists = []
    for i in tqdm.tqdm(range(num_batches)):
        batch = data[i*batch_size:(i+1)*batch_size]
        m = _batch_tensor(batch, 'masses', device)
        pos0 = _batch_tensor(batch, 'pos', device)
        vel0 = _batch_tensor(batch, 'vel', device)
        # Ground-truth accelerations for rollout steps 0 and 1.
        gt_accs = [_batch_tensor(batch, 'acc', device), _batch_tensor(batch, 'acc_next1', device)]
        # Ground-truth positions reached after steps 0 and 1; only the custom
        # loss consumes them, so the keys are not required otherwise.
        gt_positions = None
        if use_custom_loss:
            gt_positions = [_batch_tensor(batch, 'pos_next1', device), _batch_tensor(batch, 'pos_next2', device)]

        optimizer.zero_grad()
        losses = []
        for j in range(len(batch)):
            # assumes per-sample masses are 1-D so this yields a column - TODO confirm
            sample_masses = m[j].unsqueeze(1)
            pos, vel = pos0[j], vel0[j]

            sample_loss = 0
            for step, gt_acc in enumerate(gt_accs):
                pr_acc = model(pos, vel, sample_masses)
                # Roll the state forward with the *predicted* acceleration so
                # the next step sees the model's own output.
                pos, vel = get_new_pos_vel(pr_acc, pos, vel)

                if use_custom_loss:
                    step_loss, step_dist = loss_fn(pr_acc, gt_acc[j], model.num_neighbors, pos, gt_positions[step][j])
                else:
                    step_loss = torch.mean(euclidean_distance(pr_acc, gt_acc[j]))
                    step_dist = step_loss

                # Store a plain float: keeping the tensor would hold on to the
                # sample's autograd graph for the rest of the epoch.
                all_dists.append(step_dist.item())
                sample_loss += step_loss * step_weight

            losses.append(sample_loss)

        total_loss = loss_scale * sum(losses) / len(batch)
        all_losses.append(total_loss.item())
        total_loss.backward()

        optimizer.step()

    if not all_losses:
        # Fewer samples than one batch: nothing was trained, and the averages
        # below would divide by zero.
        print(f'No full batch available ({len(data)} samples, batch size {batch_size}); skipped training.')
        return

    print(f'Train Loss: {sum(all_losses)/len(all_losses)}, Train L2: {sum(all_dists)/len(all_dists)}')


def _batch_tensor(batch, key, device):
    """Stack field `key` of every sample in `batch` into one float32 tensor on `device`."""
    return torch.tensor([b[key] for b in batch], dtype=torch.float32).to(device)

        


            

In [8]:

def train(model, train_data, val_dir, batch_size, loss_fn, optimizer, num_epochs, weights_dir=None, device='cuda', eval=True, use_custom_loss=False):
    """Train `model` for `num_epochs` epochs and save its weights.

    If `weights_dir` is given, the checkpoint with the highest numeric index
    in that directory (file names shaped like ``model_<N>.pt``) is loaded
    before training, and the new checkpoint is numbered ``N + 1``. Without
    usable checkpoints the new checkpoint is numbered 0.

    Args:
        model: module to train; moved to `device`.
        train_data: dataset passed to `train_epoch`.
        val_dir: accepted for interface compatibility; no validation is run.
        batch_size: batch size passed to `train_epoch`.
        loss_fn: loss callable passed to `train_epoch`.
        optimizer: optimizer wrapping the model's parameters.
        num_epochs: number of passes over `train_data`.
        weights_dir: directory to resume from, or None to start fresh.
        device: device to train on.
        eval: accepted for interface compatibility; currently unused.
        use_custom_loss: forwarded to `train_epoch`.

    Returns:
        The trained model. The weights are also written to
        ``./models/model_<index>.pt`` (always under ``./models``, regardless
        of `weights_dir`).
    """
    model.to(device)

    # Default index when there is nothing to resume from. It must be defined
    # even when weights_dir is None, otherwise the save below fails after
    # all the training work has been done.
    last_model = 0

    if weights_dir is not None:
        checkpoints = []
        for name in os.listdir(weights_dir):
            index = _checkpoint_index(name)
            if index is not None:  # ignore files that are not numbered checkpoints
                checkpoints.append((index, name))

        if checkpoints:
            latest_index, latest_name = max(checkpoints)
            # Take the next index from the directory listing, independent of
            # whether loading succeeds, so a failed load can never make the
            # save below overwrite an existing lower-numbered checkpoint.
            last_model = latest_index + 1
            try:
                state = torch.load(os.path.join(weights_dir, latest_name), map_location=device)
                model.load_state_dict(state)
                print(f'Loaded weights from {latest_name}')
            except Exception as exc:
                # Best effort: continue from the model's current weights, but
                # say so instead of failing silently.
                print(f'Could not load weights from {latest_name}: {exc}')

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}')
        train_epoch(model, train_data, batch_size, loss_fn, optimizer, device, use_custom_loss)

    # Make sure the target directory exists so the trained weights are not
    # lost to a missing-directory error at the very end.
    os.makedirs('./models', exist_ok=True)
    torch.save(model.state_dict(), f'./models/model_{last_model}.pt')

    return model


def _checkpoint_index(filename):
    """Return the integer N from a name shaped like 'model_N.pt', or None if it does not parse."""
    try:
        return int(filename.split('_')[1].split('.')[0])
    except (IndexError, ValueError):
        return None

In [9]:
# Clear the CUDA cache: ask PyTorch to release cached GPU memory it is not
# currently using, so a fresh training run starts with as much free memory as possible.

torch.cuda.empty_cache()

In [None]:

# --- Run configuration ------------------------------------------------------
MODELS_DIR = './models'
NUM_ROUNDS = 40           # number of generate-then-train rounds
SCENES_PER_ROUND = 25     # scenes generated for each round's dataset
BATCH_SIZE = 64
EPOCHS_PER_ROUND = 1
LR_DECAY_EVERY = 5        # scheduler.step() is called when round % LR_DECAY_EVERY == 0

# A fresh checkout may not have the checkpoint directory yet.
os.makedirs(MODELS_DIR, exist_ok=True)

# Find the highest existing checkpoint index (files shaped like model_<N>.pt).
# Files that do not match that shape are skipped rather than crashing the run.
checkpoint_ids = []
for filename in os.listdir(MODELS_DIR):
    try:
        checkpoint_ids.append(int(filename.split('_')[1].split('.')[0]))
    except (IndexError, ValueError):
        continue
last_model_id = max(checkpoint_ids, default=-1)  # -1 means no checkpoint yet

model = KwisatzHaderach(activation=True, layer_channels=[64, 64, 64, 3], calc_neighbors=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.7, verbose=True)

for i in range(NUM_ROUNDS):
    # A new in-memory dataset is generated for every round.
    dataset = generate_dataset_memory(SCENES_PER_ROUND, window_size=3)

    # train() is given MODELS_DIR as weights_dir, so it reloads the newest
    # checkpoint from disk before training and writes its own checkpoint too.
    model = train(model, dataset, './val', BATCH_SIZE, loss_fn, optimizer, EPOCHS_PER_ROUND, MODELS_DIR, device='cuda', eval=False, use_custom_loss=True)

    # Drop the dataset and release cached GPU memory before the next round.
    del dataset
    torch.cuda.empty_cache()

    last_model_id += 1
    torch.save(model.state_dict(), f'{MODELS_DIR}/model_{last_model_id}.pt')

    # NOTE(review): this fires on rounds 0, 5, 10, ... so the first decay
    # happens right after the first round - confirm that is the intended cadence.
    if i % LR_DECAY_EVERY == 0:
        scheduler.step()



In [4]:
# Alternative training path driven by the Trainer helper.
# This rebinds the notebook-level `model`, `optimizer` and `scheduler` names,
# here with a default-constructed model.
model = KwisatzHaderach()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.7, verbose=True)

trainer = Trainer(
    loss_fn=mean_distance,
    batch_size=64,
    device='cpu',
    mode='present',
)

trainer.train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    rounds=40,
    epochs_per_dataset=1,
    scenes_per_dataset=5,
)


Adjusting learning rate of group 0 to 1.0000e-03.
Round 0
Generating dataset with 5 scenes...


100%|██████████| 5/5 [01:00<00:00, 12.09s/it]


Epoch 0


100%|██████████| 78/78 [04:31<00:00,  3.48s/it]


Train Loss: 1797.9547455616487
Adjusting learning rate of group 0 to 7.0000e-04.
Saved weights to ./models/model_0.pt
Round 1
Generating dataset with 5 scenes...


  0%|          | 0/5 [00:12<?, ?it/s]


KeyboardInterrupt: 