In [5]:
import torch
from model import KwisatzHaderach
import json
import os
import tqdm

In [2]:
def euclidean_distance(a, b, epsilon=1e-9):
    return torch.sqrt(torch.sum((a - b)**2, dim=-1) + epsilon)

def loss_fn(pr_pos, gt_pos, num_fluid_neighbors):
    gamma = 0.5
    neighbor_scale = 1 / 40
    importance = torch.exp(-neighbor_scale * num_fluid_neighbors)
    return torch.mean(importance *
                        euclidean_distance(pr_pos, gt_pos)**gamma)

In [9]:
def train_epoch(model, file, batch_size, loss_fn, optimizer):

    model.train()
    with open(file) as f:
        data = json.load(f)

    num_batches = len(data) // batch_size
    all_losses = []
    for i in tqdm.tqdm(range(num_batches)):
        batch = data[i*batch_size:(i+1)*batch_size]
        m = [b['masses'] for b in batch]
        pos0 = [b['pos'] for b in batch]
        vel0 = [b['vel'] for b in batch]
        pos1 = [b['pos_next1'] for b in batch]
        vel1 = [b['vel_next1'] for b in batch]
        pos2 = [b['pos_next2'] for b in batch]
        vel2 = [b['vel_next2'] for b in batch]

        optimizer.zero_grad()
        loss = 0
        for j in range(len(batch)):
            sample_masses = torch.tensor(m[j], dtype=torch.float32).unsqueeze(1)
            sample_pos0 = torch.tensor(pos0[j], dtype=torch.float32)
            sample_vel0 = torch.tensor(vel0[j], dtype=torch.float32)
            sample_pos1 = torch.tensor(pos1[j], dtype=torch.float32)
            sample_vel1 = torch.tensor(vel1[j], dtype=torch.float32)
            sample_pos2 = torch.tensor(pos2[j], dtype=torch.float32)
            sample_vel2 = torch.tensor(vel2[j], dtype=torch.float32)
            


            pr_pos1, pr_vel1 = model(sample_pos0, sample_vel0, sample_masses)

            loss += 0.5*loss_fn(pr_pos1, sample_pos1, model.num_neighbors)

            pr_pos2, pr_vel2 = model(pr_pos1, pr_vel1, sample_masses)

            loss += 0.5*loss_fn(pr_pos2, sample_pos2, model.num_neighbors)

        loss /= len(batch)
        loss.backward()

        optimizer.step()
        all_losses.append(loss.item())
    print(f'Train Loss: {sum(all_losses)/len(all_losses)}')

def val(model, val_dir, batch_size, loss_fn):
    files = os.listdir(val_dir)

    model.eval()
    neigbor_num = []
    with torch.no_grad():
        total_loss = 0
        for file in files:
            with open(os.path.join(val_dir, file)) as f:
                data = json.load(f)
            num_batches = len(data) // batch_size
            for i in range(num_batches):
                batch = data[i*batch_size:(i+1)*batch_size]
                m = [b['masses'] for b in batch]
                pos0 = [b['pos'] for b in batch]
                vel0 = [b['vel'] for b in batch]
                pos1 = [b['pos_next1'] for b in batch]
                vel1 = [b['vel_next1'] for b in batch]
                pos2 = [b['pos_next2'] for b in batch]
                vel2 = [b['vel_next2'] for b in batch]

                neihbors_batch = []

                loss = 0
                for j in range(len(batch)):
                    sample_masses = torch.tensor(m[j], dtype=torch.float32).unsqueeze(1)
                    sample_pos0 = torch.tensor(pos0[j], dtype=torch.float32)
                    sample_vel0 = torch.tensor(vel0[j], dtype=torch.float32)
                    sample_pos1 = torch.tensor(pos1[j], dtype=torch.float32)
                    sample_vel1 = torch.tensor(vel1[j], dtype=torch.float32)
                    sample_pos2 = torch.tensor(pos2[j], dtype=torch.float32)
                    sample_vel2 = torch.tensor(vel2[j], dtype=torch.float32)

                    neighbors = 0

                    pr_pos1, pr_vel1 = model(sample_pos0, sample_vel0, sample_masses)
                    loss += 0.5*loss_fn(pr_pos1, sample_pos1, model.num_neighbors)

                    neighbors += model.num_neighbors * 0.5

                    pr_pos2, pr_vel2 = model(pr_pos1, pr_vel1, sample_masses)
                    loss += 0.5*loss_fn(pr_pos2, sample_pos2, model.num_neighbors)
                    
                    neighbors += model.num_neighbors * 0.5
                    neihbors_batch.append(neighbors)
                neigbor_num.append(sum(neihbors_batch)/len(neihbors_batch))


                loss /= len(batch)
                total_loss += loss.item()
    print(f'Validation Loss: {total_loss/len(files)}')
    print(f'Average Neighbors: {sum(neigbor_num)/len(neigbor_num)}')

            

In [10]:
def train(model, train_dir, val_dir, batch_size, loss_fn, optimizer, num_epochs):
    train_files = os.listdir(train_dir)
    for epoch in range(num_epochs):
        for t_file in train_files:
            full_path = os.path.join(train_dir, t_file)
            train_epoch(model, full_path, batch_size, loss_fn, optimizer)

        val(model, val_dir, batch_size, loss_fn)

        torch.save(model.state_dict(), f'./models/model_{epoch}.pt')

In [11]:
model = KwisatzHaderach()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train(model, './train', './val', 128, loss_fn, optimizer, 5)



100%|██████████| 38/38 [03:52<00:00,  6.11s/it]


Train Loss: 0.4094385755689521


100%|██████████| 38/38 [03:39<00:00,  5.77s/it]


Train Loss: 0.15175206920034007


JSONDecodeError: Expecting ',' delimiter: line 5288890 column 36 (char 146894873)

In [12]:
val(model, './val', 128, loss_fn)

Validation Loss: 0.9065078496932983
