In [21]:
import torch
from kwisatzHaderach import KwisatzHaderach
import json
import os
import tqdm
from datagen import generate_dataset, generate_dataset_memory

In [22]:
def euclidean_distance(a, b):
    """Return the L2 norm of ``a - b`` taken over the last axis, plus 1e-12.

    The constant is added after the norm is computed, so the result stays
    strictly positive even where ``a`` and ``b`` coincide.
    """
    difference = a - b
    distance = torch.linalg.norm(difference, dim=-1)
    return distance + 1e-12

def loss_fn(pr_pos, gt_pos, num_fluid_neighbors, gamma=0.5, neighbor_scale=1 / 50):
    """Neighbor-weighted position loss plus the plain mean distance.

    Args:
        pr_pos: predicted positions.
        gt_pos: ground-truth positions, same shape as ``pr_pos``.
        num_fluid_neighbors: tensor of neighbor counts used to weight each
            distance (presumably one entry per particle -- confirm against
            the model that produces it).
        gamma: exponent applied to each distance before weighting.
        neighbor_scale: decay rate of the weight ``exp(-neighbor_scale * n)``.

    Returns:
        ``(loss, mean_distance)``. ``loss`` is
        ``mean(exp(-neighbor_scale * n) * distance ** gamma)``; when the
        neighbor tensor has no elements the unweighted mean distance is
        returned in both positions.
    """
    euclidean_distances = euclidean_distance(pr_pos, gt_pos)
    mean_distance = torch.mean(euclidean_distances)

    importance = torch.exp(-neighbor_scale * num_fluid_neighbors)
    # numel() covers every empty shape and, unlike size()[0], does not raise
    # on a 0-dim tensor.
    if importance.numel() == 0:
        return mean_distance, mean_distance

    weighted = importance * euclidean_distances**gamma
    return torch.mean(weighted), mean_distance

In [23]:
def train_epoch(model, file, batch_size, loss_fn, optimizer, device):
    """Run one optimisation pass over the samples stored in a JSON file.

    Args:
        model: callable as ``model(pos, vel, masses) -> (pos, vel)`` that
            exposes ``num_neighbors`` after each call.
        file: path to a JSON file holding a list of sample dicts with keys
            ``'masses'``, ``'pos'``, ``'vel'``, ``'pos_next1'`` and
            ``'pos_next2'``.
        batch_size: number of samples whose losses are averaged per
            optimizer step.
        loss_fn: callable ``(pr_pos, gt_pos, num_neighbors) -> (loss, dist)``.
        optimizer: optimizer that is zeroed and stepped once per batch.
        device: device the batch tensors are moved to.

    Prints the mean batch loss and the mean per-step distance. Returns None.
    """

    model.train()
    with open(file) as f:
        data = json.load(f)

    def stack(batch, key):
        # All samples of a batch go into one tensor, so they presumably must
        # share the same particle count -- TODO confirm with the data generator.
        return torch.tensor([b[key] for b in batch], dtype=torch.float32).to(device)

    # Weight of each rollout step in the per-sample loss, paired with the
    # key of the ground-truth positions for that step.
    rollout = ((0.5, 'pos_next1'), (0.5, 'pos_next2'))
    loss_scale = 128

    # Floor division: a trailing partial batch is not trained on.
    num_batches = len(data) // batch_size
    all_losses = []
    all_dists = []
    for i in tqdm.tqdm(range(num_batches)):
        batch = data[i*batch_size:(i+1)*batch_size]
        m = stack(batch, 'masses')
        pos0 = stack(batch, 'pos')
        vel0 = stack(batch, 'vel')
        targets = [stack(batch, key) for _, key in rollout]

        optimizer.zero_grad()
        losses = []
        for j in range(len(batch)):
            l = 0
            sample_masses = m[j].unsqueeze(1)
            pr_pos, pr_vel = pos0[j], vel0[j]

            # Each step is fed the previous prediction, so gradients flow
            # through the whole rollout.
            for (weight, _), target in zip(rollout, targets):
                pr_pos, pr_vel = model(pr_pos, pr_vel, sample_masses)
                step_loss, step_dist = loss_fn(pr_pos, target[j], model.num_neighbors)

                # Log the raw distance; scaling it by the loss weight here
                # would make the averaged "Train L2" half the real value.
                all_dists.append(step_dist.item())
                l += weight * step_loss

            losses.append(l)

        total_loss = loss_scale * sum(losses) / len(batch)
        all_losses.append(total_loss.item())
        total_loss.backward()

        optimizer.step()

    # With fewer samples than batch_size no batch runs; report NaN instead
    # of dividing by zero.
    mean_loss = sum(all_losses) / len(all_losses) if all_losses else float('nan')
    mean_dist = sum(all_dists) / len(all_dists) if all_dists else float('nan')
    print(f'Train Loss: {mean_loss}, Train L2: {mean_dist}')

def train_epoch_memory(model, data, batch_size, loss_fn, optimizer, device):
    """Run one optimisation pass over an in-memory list of samples.

    Args:
        model: callable as ``model(pos, vel, masses) -> (pos, vel)`` that
            exposes ``num_neighbors`` after each call.
        data: sequence of sample dicts with keys ``'masses'``, ``'pos'``,
            ``'vel'`` and ``'pos_next1'`` .. ``'pos_next4'``.
        batch_size: number of samples whose losses are averaged per
            optimizer step.
        loss_fn: callable ``(pr_pos, gt_pos, num_neighbors) -> (loss, dist)``.
        optimizer: optimizer that is zeroed and stepped once per batch.
        device: device the batch tensors are moved to.

    Prints the mean batch loss and the mean per-step distance. Returns None.
    """

    model.train()

    def stack(batch, key):
        # All samples of a batch go into one tensor, so they presumably must
        # share the same particle count -- TODO confirm with the data generator.
        return torch.tensor([b[key] for b in batch], dtype=torch.float32).to(device)

    # Four-step rollout; later steps carry a larger share of the loss.
    rollout = (
        (0.1, 'pos_next1'),
        (0.2, 'pos_next2'),
        (0.3, 'pos_next3'),
        (0.4, 'pos_next4'),
    )
    loss_scale = 128

    # Floor division: a trailing partial batch is not trained on.
    num_batches = len(data) // batch_size
    all_losses = []
    all_dists = []
    for i in tqdm.tqdm(range(num_batches)):
        batch = data[i*batch_size:(i+1)*batch_size]
        m = stack(batch, 'masses')
        pos0 = stack(batch, 'pos')
        vel0 = stack(batch, 'vel')
        targets = [stack(batch, key) for _, key in rollout]

        optimizer.zero_grad()
        losses = []
        for j in range(len(batch)):
            l = 0
            sample_masses = m[j].unsqueeze(1)
            pr_pos, pr_vel = pos0[j], vel0[j]

            # Each step is fed the previous prediction, so gradients flow
            # through the whole rollout.
            for (weight, _), target in zip(rollout, targets):
                pr_pos, pr_vel = model(pr_pos, pr_vel, sample_masses)
                step_loss, step_dist = loss_fn(pr_pos, target[j], model.num_neighbors)

                all_dists.append(step_dist.item())
                l += weight * step_loss

            losses.append(l)

        total_loss = loss_scale * sum(losses) / len(batch)
        all_losses.append(total_loss.item())
        total_loss.backward()

        optimizer.step()

    # With fewer samples than batch_size no batch runs; report NaN instead
    # of dividing by zero.
    mean_loss = sum(all_losses) / len(all_losses) if all_losses else float('nan')
    mean_dist = sum(all_dists) / len(all_dists) if all_dists else float('nan')
    print(f'Train Loss: {mean_loss}, Train L2: {mean_dist}')



def _val_file(model, path, loss_fn, device):
    """Evaluate a two-step rollout on every sample of one JSON file.

    Returns ``(mean_loss, dists)`` where ``mean_loss`` is a float and
    ``dists`` holds one distance per rollout step per sample, or
    ``(None, [])`` when the file contains no samples.
    """
    with open(path) as f:
        data = json.load(f)
    if not data:
        return None, []

    rollout = ((0.5, 'pos_next1'), (0.5, 'pos_next2'))

    loss = 0
    dists = []
    for sample in data:
        sample_masses = torch.tensor(sample['masses'], dtype=torch.float32).unsqueeze(1).to(device)
        pr_pos = torch.tensor(sample['pos'], dtype=torch.float32).to(device)
        pr_vel = torch.tensor(sample['vel'], dtype=torch.float32).to(device)

        for weight, key in rollout:
            target = torch.tensor(sample[key], dtype=torch.float32).to(device)
            # The second step starts from the first step's prediction.
            pr_pos, pr_vel = model(pr_pos, pr_vel, sample_masses)
            step_loss, step_dist = loss_fn(pr_pos, target, model.num_neighbors)
            loss += weight * step_loss
            dists.append(step_dist.item())

    return (loss / len(data)).item(), dists


def val(model, val_dir, batch_size, loss_fn, device):
    """Evaluate ``model`` on every JSON file in ``val_dir`` and print the result.

    Args:
        model: callable as ``model(pos, vel, masses) -> (pos, vel)`` that
            exposes ``num_neighbors`` after each call.
        val_dir: directory whose entries are each read as a JSON list of
            sample dicts with keys ``'masses'``, ``'pos'``, ``'vel'``,
            ``'pos_next1'`` and ``'pos_next2'``.
        batch_size: unused; kept so the signature matches the training
            functions.
        loss_fn: callable ``(pr_pos, gt_pos, num_neighbors) -> (loss, dist)``.
        device: device the sample tensors are moved to.

    Prints the mean per-file loss and the mean per-step distance; files
    without samples are skipped. Returns None.
    """
    files = os.listdir(val_dir)

    model.eval()
    all_losses = []
    all_dists = []
    with torch.no_grad():
        for file in files:
            file_loss, file_dists = _val_file(model, os.path.join(val_dir, file), loss_fn, device)
            if file_loss is not None:
                all_losses.append(file_loss)
                all_dists.extend(file_dists)

            # The per-file tensors went out of scope when the helper
            # returned; release the cached blocks before the next file.
            torch.cuda.empty_cache()

    # An empty directory (or only empty files) yields NaN rather than a
    # ZeroDivisionError.
    mean_loss = sum(all_losses) / len(all_losses) if all_losses else float('nan')
    mean_dist = sum(all_dists) / len(all_dists) if all_dists else float('nan')
    print(f'Val Loss: {mean_loss}, Val L2: {mean_dist}')


            

In [24]:
def train(model, train_dir, val_dir, batch_size, loss_fn, optimizer, num_epochs, weights_dir=None, device='cuda', eval=True):
    """Train ``model`` on every file in ``train_dir`` and save the result.

    Args:
        model: module to train; moved to ``device`` first.
        train_dir: directory whose files are each passed to ``train_epoch``
            once per epoch.
        val_dir: directory passed to ``val`` after each epoch when ``eval``
            is true.
        batch_size, loss_fn, optimizer: forwarded to ``train_epoch`` / ``val``.
        num_epochs: number of passes over ``train_dir``.
        weights_dir: optional directory of ``model_<N>.*`` checkpoints. The
            one with the highest ``N`` is loaded and the new checkpoint is
            numbered ``N + 1``; otherwise numbering starts at 0.
        device: device used for the model and for loading the checkpoint.
        eval: whether to run validation after each epoch.

    Returns:
        The trained model. Its state dict is also written to
        ``./models/model_<index>.pt``.
    """

    model.to(device)

    # Default index, so saving also works when no weights_dir is given.
    last_model = 0

    if weights_dir is not None:
        # Keep only names of the form model_<int>.*, and compare the
        # integers: a plain string sort would rank model_10 before model_2.
        checkpoints = []
        for name in os.listdir(weights_dir):
            try:
                checkpoints.append((int(name.split('_')[1].split('.')[0]), name))
            except (IndexError, ValueError):
                continue

        if checkpoints:
            latest_index, latest_name = max(checkpoints)
            # Number after the newest checkpoint even if loading it fails,
            # so an existing file is never overwritten.
            last_model = latest_index + 1
            try:
                state = torch.load(os.path.join(weights_dir, latest_name), map_location=device)
                model.load_state_dict(state)
            except Exception as exc:
                # Best effort: keep training from the current weights, but
                # say so. KeyboardInterrupt is no longer swallowed.
                print(f'Could not load weights from {latest_name}: {exc}')

    train_files = os.listdir(train_dir)
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}')
        for t_file in train_files:
            full_path = os.path.join(train_dir, t_file)
            train_epoch(model, full_path, batch_size, loss_fn, optimizer, device)

        if eval:
            val(model, val_dir, batch_size, loss_fn, device)

    # Make sure the output directory exists before writing the result.
    os.makedirs('./models', exist_ok=True)
    torch.save(model.state_dict(), f'./models/model_{last_model}.pt')

    return model

def train_memory(model, train_data, val_dir, batch_size, loss_fn, optimizer, num_epochs, weights_dir=None, device='cuda', eval=True):
    """Train ``model`` on an in-memory dataset and save the result.

    Args:
        model: module to train; moved to ``device`` first.
        train_data: samples passed to ``train_epoch_memory`` once per epoch.
        val_dir: directory passed to ``val`` after each epoch when ``eval``
            is true.
        batch_size, loss_fn, optimizer: forwarded to ``train_epoch_memory``
            / ``val``.
        num_epochs: number of passes over ``train_data``.
        weights_dir: optional directory of ``model_<N>.*`` checkpoints. The
            one with the highest ``N`` is loaded and the new checkpoint is
            numbered ``N + 1``; otherwise numbering starts at 0.
        device: device used for the model and for loading the checkpoint.
        eval: whether to run validation after each epoch.

    Returns:
        The trained model. Its state dict is also written to
        ``./models/model_<index>.pt``.
    """

    model.to(device)

    # Default index, so saving also works when no weights_dir is given.
    last_model = 0

    if weights_dir is not None:
        # Keep only names of the form model_<int>.*; anything else in the
        # directory is ignored instead of breaking the numeric sort.
        checkpoints = []
        for name in os.listdir(weights_dir):
            try:
                checkpoints.append((int(name.split('_')[1].split('.')[0]), name))
            except (IndexError, ValueError):
                continue

        if checkpoints:
            latest_index, latest_name = max(checkpoints)
            # Number after the newest checkpoint even if loading it fails,
            # so an existing file is never overwritten.
            last_model = latest_index + 1
            try:
                state = torch.load(os.path.join(weights_dir, latest_name), map_location=device)
                model.load_state_dict(state)
                print(f'Loaded weights from {latest_name}')
            except Exception as exc:
                # Best effort: keep training from the current weights, but
                # say so. KeyboardInterrupt is no longer swallowed.
                print(f'Could not load weights from {latest_name}: {exc}')

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}')
        train_epoch_memory(model, train_data, batch_size, loss_fn, optimizer, device)

        if eval:
            val(model, val_dir, batch_size, loss_fn, device)

    # Make sure the output directory exists before writing the result.
    os.makedirs('./models', exist_ok=True)
    torch.save(model.state_dict(), f'./models/model_{last_model}.pt')

    return model

In [25]:
def _model_index(filename):
    """Return N for a ``model_<N>.<ext>`` filename, or None if it does not match."""
    try:
        return int(filename.split('_')[1].split('.')[0])
    except (IndexError, ValueError):
        return None


# Run configuration.
NUM_ROUNDS = 10          # dataset-generation + training rounds
SCENES_PER_ROUND = 25    # argument passed to generate_dataset_memory
BATCH_SIZE = 16
LR_DECAY = 0.5           # multiplicative factor applied by the scheduler

# Only files that follow the model_<N>.* pattern count as checkpoints, so
# stray entries in ./models no longer break the numeric sort.
model_files = sorted(
    (name for name in os.listdir('./models/') if _model_index(name) is not None),
    key=_model_index,
)
# -1 means "no checkpoint yet"; the first save below then writes model_0.pt.
last_model_id = _model_index(model_files[-1]) if model_files else -1

model = KwisatzHaderach(activation=True, layer_channels=[64, 64, 32, 3])

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# NOTE(review): `verbose` is deprecated in newer PyTorch releases -- confirm
# against the installed version before upgrading.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, LR_DECAY, verbose=True)

for i in range(NUM_ROUNDS):
    # remove all files from train directory
    files = os.listdir('./train')
    for file in files:
        os.remove(os.path.join('./train', file))

    dataset = generate_dataset_memory(SCENES_PER_ROUND)

    # One epoch per freshly generated dataset; validation is skipped.
    model = train_memory(model, dataset, './val', BATCH_SIZE, loss_fn, optimizer, 1, './models', device='cuda', eval=False)
    del dataset
    last_model_id += 1
    # train_memory also writes a checkpoint under ./models; this save keeps
    # the numbering tracked in this cell explicit.
    torch.save(model.state_dict(), f'./models/model_{last_model_id}.pt')

    # Decay the learning rate after every even-numbered round (0, 2, 4, ...).
    if i % 2 == 0:
        scheduler.step()


Adjusting learning rate of group 0 to 1.0000e-03.
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:57<00:00, 16.71s/it]


Loaded weights from model_10.pt
Epoch 0


100%|██████████| 1548/1548 [16:42<00:00,  1.54it/s]


Train Loss: 0.00705320183103495, Train L2: 0.0009586746525507151
Adjusting learning rate of group 0 to 5.0000e-04.
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:08<00:00, 14.74s/it]


Loaded weights from model_11.pt
Epoch 0


100%|██████████| 1548/1548 [16:41<00:00,  1.55it/s]


Train Loss: 0.006597265546334647, Train L2: 0.0009881174070850065
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:13<00:00, 14.93s/it]


Loaded weights from model_12.pt
Epoch 0


100%|██████████| 1548/1548 [16:45<00:00,  1.54it/s]


Train Loss: 0.006672489968367951, Train L2: 0.0009955053426794471
Adjusting learning rate of group 0 to 2.5000e-04.
Generating dataset with 25 scenes...


  0%|          | 0/25 [00:06<?, ?it/s]


KeyboardInterrupt: 