In [1]:
import torch
from kwisatzHaderach import KwisatzHaderach
import json
import os
import tqdm
import numpy as np
from datagen import generate_dataset, generate_dataset_memory, generate_dataset_memory_bh

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def euclidean_distance(a, b):
    """Return the Euclidean (L2) distance between ``a`` and ``b`` along the last dimension.

    A constant of 1e-12 is added under the square root, so identical inputs
    yield 1e-6 rather than exactly 0 (presumably to keep the gradient of the
    square root finite at zero distance).
    """
    squared_sum = (a - b).pow(2).sum(dim=-1)
    return torch.sqrt(squared_sum + 1e-12)
# NOTE: the triple-quoted string below is NOT executed. It preserves an earlier,
# acceleration-only version of loss_fn (no position term) for reference.
# Because the string is the last expression in this cell, Jupyter echoes it
# back as the cell's output.
'''
def loss_fn(pr_acc, gt_acc, num_neighbors):
    gamma = 0.5
    neighbor_scale = 1/100
    importance = torch.exp(neighbor_scale * num_neighbors) # removed minus sign to give more importance to particles with more neighbors
    importance = importance / torch.max(importance)
    euclidean_distances = euclidean_distance(pr_acc, gt_acc)
    if importance.size()[0] == 0:
        importance = 1.0
    return torch.mean(importance *
                        euclidean_distances**gamma), torch.mean(euclidean_distances)
'''

'\ndef loss_fn(pr_acc, gt_acc, num_neighbors):\n    gamma = 0.5\n    neighbor_scale = 1/100\n    importance = torch.exp(neighbor_scale * num_neighbors) # removed minus sign to give more importance to particles with more neighbors\n    importance = importance / torch.max(importance)\n    euclidean_distances = euclidean_distance(pr_acc, gt_acc)\n    if importance.size()[0] == 0:\n        importance = 1.0\n    return torch.mean(importance *\n                        euclidean_distances**gamma), torch.mean(euclidean_distances)\n'

In [3]:
def loss_fn(pr_acc, gt_acc, num_neighbors, gt_pos=None, pr_pos=None, pos_importance=1.0):
    """Neighbor-weighted loss on accelerations, optionally plus a position term.

    Every particle's error is weighted by ``exp(num_neighbors / 100)`` divided
    by the largest such weight, so particles with more neighbors count more.
    Errors are Euclidean distances raised to the power ``gamma = 0.5``.

    Args:
        pr_acc: predicted values, compared with ``gt_acc`` along the last dim.
        gt_acc: ground-truth values with the same shape as ``pr_acc``.
        num_neighbors: tensor of per-particle neighbor counts. If it is empty,
            a uniform weight of 1.0 is used instead.
        gt_pos: optional ground-truth positions.
        pr_pos: optional predicted positions. The position term is added only
            when both ``gt_pos`` and ``pr_pos`` are given, so the
            three-argument calls made by ``train_epoch`` and ``val`` work.
        pos_importance: multiplier applied to the position term.

    Returns:
        ``(loss, mean_distance)`` where ``loss`` is the weighted training
        objective and ``mean_distance`` is the plain mean Euclidean distance
        between ``pr_acc`` and ``gt_acc`` (used for reporting).
    """
    gamma = 0.5
    neighbor_scale = 1/100

    # Test for emptiness *before* calling torch.max, which raises on an empty
    # tensor (the original check came too late to ever take effect).
    if num_neighbors.numel() == 0:
        importance = 1.0
    else:
        # Positive exponent: more neighbors -> more importance.
        importance = torch.exp(neighbor_scale * num_neighbors)
        importance = importance / torch.max(importance)

    euclidean_distances = euclidean_distance(pr_acc, gt_acc)
    loss = torch.mean(importance * euclidean_distances**gamma)

    if gt_pos is not None and pr_pos is not None:
        position_error = euclidean_distance(pr_pos, gt_pos)**gamma
        loss = loss + pos_importance * torch.mean(importance * position_error)

    return loss, torch.mean(euclidean_distances)

In [4]:
def get_new_pos_vel(acc, pos, vel, dt=0.01):
    """Advance one time step of length ``dt``.

    The velocity is updated from the acceleration first, and the position is
    then moved using the *updated* velocity.

    Returns:
        ``(new_pos, new_vel)``
    """
    velocity_change = acc * dt
    next_vel = vel + velocity_change
    displacement = next_vel * dt
    next_pos = pos + displacement
    return next_pos, next_vel

In [5]:
def train_epoch(model, file, batch_size, loss_fn, optimizer, device):
    """Train for one pass over the samples stored in a JSON file.

    The model is called as ``model(pos, vel, masses)`` and returns the next
    ``(pos, vel)``. Each sample is rolled out for two steps; the second step
    starts from the model's own first-step prediction. The two step losses are
    averaged with equal weight.

    Args:
        model: module to train; must expose ``num_neighbors`` after a forward pass.
        file: path to a JSON list of dicts with keys 'masses', 'pos', 'vel',
            'pos_next1' and 'pos_next2'.
        batch_size: samples per optimizer step. A trailing partial batch is dropped.
        loss_fn: called as ``loss_fn(pred, target, num_neighbors)``; returns
            ``(loss, mean_distance)``.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.

    Prints the mean batch loss and mean per-step L2 distance; returns None.
    """
    model.train()
    with open(file) as f:
        data = json.load(f)

    num_batches = len(data) // batch_size
    if num_batches == 0:
        # Without this guard the summary below would divide by zero.
        print(f'Train: skipped {file} ({len(data)} samples, batch size {batch_size})')
        return

    def to_tensor(batch, key):
        return torch.tensor([b[key] for b in batch], dtype=torch.float32).to(device)

    all_losses = []
    all_dists = []
    for i in tqdm.tqdm(range(num_batches)):
        batch = data[i*batch_size:(i+1)*batch_size]
        m = to_tensor(batch, 'masses')
        pos0 = to_tensor(batch, 'pos')
        vel0 = to_tensor(batch, 'vel')
        pos1 = to_tensor(batch, 'pos_next1')
        pos2 = to_tensor(batch, 'pos_next2')

        optimizer.zero_grad()
        losses = []
        for j in range(len(batch)):
            sample_masses = m[j].unsqueeze(1)

            # Step 1: predict from the ground-truth initial state.
            pr_pos1, pr_vel1 = model(pos0[j], vel0[j], sample_masses)
            loss1, dists1 = loss_fn(pr_pos1, pos1[j], model.num_neighbors)
            # Record the distance unscaled so "Train L2" is comparable with
            # the value reported by val().
            all_dists.append(dists1.item())

            # Step 2: continue from the model's own prediction.
            pr_pos2, pr_vel2 = model(pr_pos1, pr_vel1, sample_masses)
            loss2, dists2 = loss_fn(pr_pos2, pos2[j], model.num_neighbors)
            all_dists.append(dists2.item())

            losses.append(0.5*loss1 + 0.5*loss2)

        total_loss = sum(losses) / len(batch)
        all_losses.append(total_loss.item())
        total_loss.backward()

        optimizer.step()

    print(f'Train Loss: {sum(all_losses)/len(all_losses)}, Train L2: {sum(all_dists)/len(all_dists)}')

def train_epoch_memory(model, data, batch_size, loss_fn, optimizer, device, use_custom_loss=False, loss_scale=64):
    """Train for one pass over an in-memory list of samples.

    The model is called as ``model(pos, vel, masses)`` and returns predicted
    accelerations. Each sample is rolled out for two steps via
    ``get_new_pos_vel``; the second step starts from the predicted state. The
    two step losses are averaged with equal weight.

    Args:
        model: module to train; must expose ``num_neighbors`` when
            ``use_custom_loss`` is True.
        data: sequence of dicts with keys 'masses', 'pos', 'vel', 'acc' and
            'acc_next1' (plus 'pos_next1' and 'pos_next2' when
            ``use_custom_loss`` is True).
        batch_size: samples per optimizer step. A trailing partial batch is dropped.
        loss_fn: used only when ``use_custom_loss`` is True; called as
            ``loss_fn(pr_acc, gt_acc, num_neighbors, gt_pos, pr_pos)`` and
            returns ``(loss, mean_distance)``.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        use_custom_loss: if False, the loss is the plain mean Euclidean
            distance between predicted and ground-truth accelerations.
        loss_scale: constant factor applied to the batch loss before
            ``backward()``. Defaults to 64, the value previously hard-coded.

    Prints the mean batch loss and mean per-step L2 distance; returns None.
    """
    model.train()

    num_batches = len(data) // batch_size
    if num_batches == 0:
        # Without this guard the summary below would divide by zero.
        print(f'Train: skipped epoch ({len(data)} samples, batch size {batch_size})')
        return

    def to_tensor(batch, key):
        return torch.tensor([b[key] for b in batch], dtype=torch.float32).to(device)

    all_losses = []
    all_dists = []
    for i in tqdm.tqdm(range(num_batches)):
        batch = data[i*batch_size:(i+1)*batch_size]
        m = to_tensor(batch, 'masses')
        pos0 = to_tensor(batch, 'pos')
        vel0 = to_tensor(batch, 'vel')
        acc0 = to_tensor(batch, 'acc')
        acc1 = to_tensor(batch, 'acc_next1')
        if use_custom_loss:
            pos1 = to_tensor(batch, 'pos_next1')
            pos2 = to_tensor(batch, 'pos_next2')

        optimizer.zero_grad()
        losses = []
        for j in range(len(batch)):
            sample_masses = m[j].unsqueeze(1)

            # Step 1: predict from the ground-truth initial state.
            pr_acc0 = model(pos0[j], vel0[j], sample_masses)
            pr_pos1, pr_vel1 = get_new_pos_vel(pr_acc0, pos0[j], vel0[j])

            # The loss must be evaluated before the next forward pass because
            # it reads model.num_neighbors right after this call.
            if use_custom_loss:
                # Ground truth first, prediction second, matching loss_fn's
                # (gt_pos, pr_pos) parameter order.
                loss0, dists0 = loss_fn(pr_acc0, acc0[j], model.num_neighbors, pos1[j], pr_pos1)
            else:
                loss0 = torch.mean(euclidean_distance(pr_acc0, acc0[j]))
                dists0 = loss0

            # .item() detaches the metric; appending the tensor itself would
            # keep its autograd graph alive for the whole epoch.
            all_dists.append(dists0.item())

            # Step 2: continue from the predicted state.
            pr_acc1 = model(pr_pos1, pr_vel1, sample_masses)
            pr_pos2, pr_vel2 = get_new_pos_vel(pr_acc1, pr_pos1, pr_vel1)

            if use_custom_loss:
                loss1, dists1 = loss_fn(pr_acc1, acc1[j], model.num_neighbors, pos2[j], pr_pos2)
            else:
                loss1 = torch.mean(euclidean_distance(pr_acc1, acc1[j]))
                dists1 = loss1

            all_dists.append(dists1.item())

            losses.append(loss0 * 0.5 + loss1 * 0.5)

        total_loss = loss_scale * sum(losses) / len(batch)
        all_losses.append(total_loss.item())
        total_loss.backward()

        optimizer.step()

    print(f'Train Loss: {sum(all_losses)/len(all_losses)}, Train L2: {sum(all_dists)/len(all_dists)}')

def train_epoch_memory_black_hole_info(model, data, batch_size, loss_fn, optimizer, device, use_custom_loss=False):
    """Train for one pass over in-memory samples, feeding black-hole state to the model.

    The model is called as
    ``model(pos, vel, masses, bh_pos, bh_vel, bh_masses)`` and returns
    predicted accelerations. The ``bh_*`` arguments are the rows selected by
    each sample's 'bh_index'. Each sample is rolled out for two steps via
    ``get_new_pos_vel`` and the two step losses are averaged with equal weight.

    Args:
        model: module to train; must expose ``num_neighbors`` when
            ``use_custom_loss`` is True.
        data: sequence of dicts with keys 'masses', 'pos', 'vel', 'acc',
            'acc_next1' and 'bh_index'.
        batch_size: samples per optimizer step. A trailing partial batch is dropped.
        loss_fn: used only when ``use_custom_loss`` is True; called with three
            arguments ``(pr_acc, gt_acc, num_neighbors)``.
            NOTE(review): the caller must supply a loss_fn that accepts this
            three-argument form.
        optimizer: optimizer stepped once per batch.
        device: device the float tensors are moved to.
        use_custom_loss: if False, the loss is the plain mean Euclidean
            distance between predicted and ground-truth accelerations.

    Prints the mean batch loss and mean per-step L2 distance; returns None.
    """
    model.train()

    num_batches = len(data) // batch_size
    if num_batches == 0:
        # Without this guard the summary below would divide by zero.
        print(f'Train: skipped epoch ({len(data)} samples, batch size {batch_size})')
        return

    def to_tensor(batch, key):
        return torch.tensor([b[key] for b in batch], dtype=torch.float32).to(device)

    all_losses = []
    all_dists = []
    for i in tqdm.tqdm(range(num_batches)):
        batch = data[i*batch_size:(i+1)*batch_size]
        m = to_tensor(batch, 'masses')
        pos0 = to_tensor(batch, 'pos')
        vel0 = to_tensor(batch, 'vel')
        acc0 = to_tensor(batch, 'acc')
        acc1 = to_tensor(batch, 'acc_next1')
        black_hole_indexes = torch.tensor(np.array([b['bh_index'] for b in batch]), dtype=torch.long)

        optimizer.zero_grad()
        losses = []
        for j in range(len(batch)):
            sample_masses = m[j].unsqueeze(1)
            sample_pos0 = pos0[j]
            sample_vel0 = vel0[j]
            bh = black_hole_indexes[j]
            # Masses do not change during the rollout, so the black-hole
            # masses are selected once and reused for both steps.
            sample_masses_bh = sample_masses[bh]

            # Step 1: predict from the ground-truth initial state.
            pr_acc0 = model(sample_pos0, sample_vel0, sample_masses,
                            sample_pos0[bh], sample_vel0[bh], sample_masses_bh)

            if use_custom_loss:
                loss0, dists0 = loss_fn(pr_acc0, acc0[j], model.num_neighbors)
            else:
                loss0 = torch.mean(euclidean_distance(pr_acc0, acc0[j]))
                dists0 = loss0

            all_dists.append(dists0.item())

            # Step 2: continue from the predicted state.
            pr_pos1, pr_vel1 = get_new_pos_vel(pr_acc0, sample_pos0, sample_vel0)
            pr_acc1 = model(pr_pos1, pr_vel1, sample_masses,
                            pr_pos1[bh], pr_vel1[bh], sample_masses_bh)

            if use_custom_loss:
                loss1, dists1 = loss_fn(pr_acc1, acc1[j], model.num_neighbors)
            else:
                loss1 = torch.mean(euclidean_distance(pr_acc1, acc1[j]))
                dists1 = loss1

            all_dists.append(dists1.item())

            losses.append(loss0 * 0.5 + loss1 * 0.5)

        total_loss = sum(losses) / len(batch)
        all_losses.append(total_loss.item())
        total_loss.backward()

        optimizer.step()

    print(f'Train Loss: {sum(all_losses)/len(all_losses)}, Train L2: {sum(all_dists)/len(all_dists)}')

        

def val(model, val_dir, batch_size, loss_fn, device):
    """Evaluate ``model`` on every JSON file in ``val_dir`` without gradients.

    The model is called as ``model(pos, vel, masses)`` and must return the
    next ``(pos, vel)``. Each sample is rolled out for two steps and the two
    step losses are averaged with equal weight. One loss value is recorded per
    file (the mean over that file's samples).

    Args:
        model: module to evaluate; must expose ``num_neighbors`` after a forward pass.
        val_dir: directory whose files are JSON lists of dicts with keys
            'masses', 'pos', 'vel', 'pos_next1' and 'pos_next2'.
        batch_size: unused; kept so the signature matches the training functions.
        loss_fn: called as ``loss_fn(pred, target, num_neighbors)``; returns
            ``(loss, mean_distance)``.
        device: device the sample tensors are moved to.

    Prints the mean per-file loss and mean per-step L2 distance; returns None.
    Files that contain no samples are skipped.
    """
    files = os.listdir(val_dir)

    def to_tensor(values):
        return torch.tensor(values, dtype=torch.float32)

    model.eval()
    all_losses = []
    all_dists = []
    with torch.no_grad():
        for file in files:
            with open(os.path.join(val_dir, file)) as f:
                data = json.load(f)

            if not data:
                # An empty file would otherwise divide by zero below.
                continue

            file_loss = 0
            for sample in data:
                sample_masses = to_tensor(sample['masses']).unsqueeze(1).to(device)
                sample_pos0 = to_tensor(sample['pos']).to(device)
                sample_vel0 = to_tensor(sample['vel']).to(device)
                sample_pos1 = to_tensor(sample['pos_next1']).to(device)
                sample_pos2 = to_tensor(sample['pos_next2']).to(device)

                # Step 1: predict from the ground-truth initial state.
                pr_pos1, pr_vel1 = model(sample_pos0, sample_vel0, sample_masses)
                loss1, dist1 = loss_fn(pr_pos1, sample_pos1, model.num_neighbors)
                file_loss += 0.5*loss1
                all_dists.append(float(dist1))

                # Step 2: continue from the model's own prediction.
                pr_pos2, _ = model(pr_pos1, pr_vel1, sample_masses)
                loss2, dist2 = loss_fn(pr_pos2, sample_pos2, model.num_neighbors)
                file_loss += 0.5*loss2
                all_dists.append(float(dist2))

            all_losses.append(float(file_loss / len(data)))

            # Per-sample tensors are rebound on the next iteration; just
            # release cached GPU memory between files.
            torch.cuda.empty_cache()

    if not all_losses:
        print(f'Val: no validation samples found in {val_dir}')
        return

    print(f'Val Loss: {sum(all_losses)/len(all_losses)}, Val L2: {sum(all_dists)/len(all_dists)}')


            

In [6]:
def _checkpoint_index(filename):
    """Return N for a file named like ``model_N.pt``, or None if the name does not parse."""
    try:
        return int(filename.split('_')[1].split('.')[0])
    except (IndexError, ValueError):
        return None


def _resume_from_latest(model, weights_dir, device):
    """Load the highest-numbered checkpoint from ``weights_dir`` into ``model``.

    Returns the index the next checkpoint should be saved under: 0 when
    ``weights_dir`` is None or holds no ``model_N.pt`` files, otherwise
    ``N + 1`` for the newest file. Files whose names do not parse are ignored.
    A checkpoint that cannot be loaded is reported and training continues with
    the model's current weights; numbering still continues after the newest
    file so existing checkpoints are not overwritten.
    """
    if weights_dir is None:
        return 0

    indexed = []
    for name in os.listdir(weights_dir):
        idx = _checkpoint_index(name)
        if idx is not None:
            indexed.append((idx, name))
    if not indexed:
        return 0

    # Numeric comparison: model_10 ranks after model_9.
    latest_idx, latest_name = max(indexed)
    try:
        # map_location lets a checkpoint saved on one device load on another.
        state = torch.load(os.path.join(weights_dir, latest_name), map_location=device)
        model.load_state_dict(state)
        print(f'Loaded weights from {latest_name}')
    except Exception as err:
        print(f'Could not load weights from {latest_name}: {err}')
    return latest_idx + 1


def _save_checkpoint(model, out_dir, index):
    """Save the model's state dict to ``<out_dir>/model_<index>.pt``, creating the directory if needed."""
    os.makedirs(out_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(out_dir, f'model_{index}.pt'))


def train(model, train_dir, val_dir, batch_size, loss_fn, optimizer, num_epochs, weights_dir=None, device='cuda', eval=True):
    """Train on every JSON file in ``train_dir`` for ``num_epochs`` epochs.

    Resumes from the newest checkpoint in ``weights_dir`` (if given), runs
    ``train_epoch`` on each training file per epoch, optionally runs ``val``
    after every epoch, and finally saves a checkpoint under ``./models``.

    Returns:
        The trained model.
    """
    model.to(device)
    next_id = _resume_from_latest(model, weights_dir, device)

    train_files = os.listdir(train_dir)
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}')
        for t_file in train_files:
            full_path = os.path.join(train_dir, t_file)
            train_epoch(model, full_path, batch_size, loss_fn, optimizer, device)

        if eval:
            val(model, val_dir, batch_size, loss_fn, device)

    _save_checkpoint(model, './models', next_id)

    return model


def train_memory(model, train_data, val_dir, batch_size, loss_fn, optimizer, num_epochs, weights_dir=None, device='cuda', eval=True, use_custom_loss=False):
    """Train on an in-memory dataset for ``num_epochs`` epochs.

    Resumes from the newest checkpoint in ``weights_dir`` (if given), runs
    ``train_epoch_memory`` once per epoch, optionally runs ``val`` after every
    epoch, and finally saves a checkpoint under ``./models``.

    NOTE(review): ``val`` unpacks the model output as ``(pos, vel)``, while
    ``train_epoch_memory`` treats it as a single acceleration tensor. Confirm
    the model is compatible before enabling ``eval``.

    Returns:
        The trained model.
    """
    model.to(device)
    next_id = _resume_from_latest(model, weights_dir, device)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}')
        train_epoch_memory(model, train_data, batch_size, loss_fn, optimizer, device, use_custom_loss)

        if eval:
            val(model, val_dir, batch_size, loss_fn, device)

    _save_checkpoint(model, './models', next_id)

    return model


def train_memory_black_hole_info(model, train_data, val_dir, batch_size, loss_fn, optimizer, num_epochs, weights_dir=None, device='cuda', eval=True, use_custom_loss=False):
    """Train a black-hole-aware model on an in-memory dataset.

    Same flow as ``train_memory`` but each epoch runs
    ``train_epoch_memory_black_hole_info`` and the final checkpoint is saved
    under ``./modelsbh``.

    NOTE(review): ``val`` calls the model with three arguments, while the
    black-hole training epoch calls it with six. Confirm compatibility before
    enabling ``eval``.

    Returns:
        The trained model.
    """
    model.to(device)
    next_id = _resume_from_latest(model, weights_dir, device)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}')
        train_epoch_memory_black_hole_info(model, train_data, batch_size, loss_fn, optimizer, device, use_custom_loss)

        if eval:
            val(model, val_dir, batch_size, loss_fn, device)

    _save_checkpoint(model, './modelsbh', next_id)

    return model

In [7]:
# Release unoccupied cached GPU memory held by PyTorch's caching allocator
# before the long training cell that follows.

torch.cuda.empty_cache()

In [8]:

# Find the newest checkpoint number in ./models/ so saved files continue the
# sequence; -1 means no checkpoint exists yet (the first save is model_0.pt).
model_files = sorted(os.listdir('./models/'), key=lambda x: int(x.split('_')[1].split('.')[0]))
last_model_id = int(model_files[-1].split('_')[1].split('.')[0]) if model_files else -1

model = KwisatzHaderach(activation=True, layer_channels=[64, 64, 64, 3], calc_neighbors=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.7, verbose=True)

# Each round generates a fresh in-memory dataset, trains on it for one epoch
# and writes a checkpoint. (An earlier step that emptied ./train on disk is
# disabled; the dataset is passed to train_memory directly.)
for i in range(40):
    dataset = generate_dataset_memory(25, window_size=3)

    model = train_memory(
        model,
        dataset,
        val_dir='./val',
        batch_size=64,
        loss_fn=loss_fn,
        optimizer=optimizer,
        num_epochs=1,
        weights_dir='./models',
        device='cuda',
        eval=False,
        use_custom_loss=True,
    )

    # Free the dataset and cached GPU memory before the next generation pass.
    del dataset
    torch.cuda.empty_cache()

    last_model_id += 1
    torch.save(model.state_dict(), f'./models/model_{last_model_id}.pt')

    # Decay the learning rate on rounds 0, 5, 10, ... (so the first decay
    # happens right after the first round).
    if not i % 5:
        scheduler.step()



Adjusting learning rate of group 0 to 1.0000e-03.
Generating dataset with 25 scenes...


  0%|          | 0/25 [00:00<?, ?it/s]

100%|██████████| 25/25 [06:22<00:00, 15.29s/it]


Loaded weights from model_6.pt
Epoch 0


100%|██████████| 389/389 [06:54<00:00,  1.07s/it]


Train Loss: 26.175111755920252, Train L2: 1.2557165622711182
Adjusting learning rate of group 0 to 7.0000e-04.
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:00<00:00, 14.41s/it]


Loaded weights from model_7.pt
Epoch 0


100%|██████████| 389/389 [06:38<00:00,  1.02s/it]


Train Loss: 23.87988822564368, Train L2: 1.1771546602249146
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:06<00:00, 14.65s/it]


Loaded weights from model_8.pt
Epoch 0


100%|██████████| 389/389 [06:57<00:00,  1.07s/it]


Train Loss: 27.791678195747433, Train L2: 1.3150672912597656
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:57<00:00, 14.31s/it]


Loaded weights from model_9.pt
Epoch 0


100%|██████████| 389/389 [06:44<00:00,  1.04s/it]


Train Loss: 23.65862972387005, Train L2: 1.1502450704574585
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:01<00:00, 14.45s/it]


Loaded weights from model_10.pt
Epoch 0


100%|██████████| 389/389 [06:59<00:00,  1.08s/it]


Train Loss: 25.859502238295686, Train L2: 1.2504315376281738
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:02<00:00, 14.51s/it]


Loaded weights from model_11.pt
Epoch 0


100%|██████████| 389/389 [06:56<00:00,  1.07s/it]


Train Loss: 25.47246709397029, Train L2: 1.1412330865859985
Adjusting learning rate of group 0 to 4.9000e-04.
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:59<00:00, 14.37s/it]


Loaded weights from model_12.pt
Epoch 0


100%|██████████| 389/389 [06:45<00:00,  1.04s/it]


Train Loss: 23.735741720714422, Train L2: 1.1494419574737549
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:02<00:00, 14.50s/it]


Loaded weights from model_13.pt
Epoch 0


100%|██████████| 389/389 [06:51<00:00,  1.06s/it]


Train Loss: 23.699333681238954, Train L2: 1.1125510931015015
Generating dataset with 25 scenes...


100%|██████████| 25/25 [07:10<00:00, 17.23s/it]


Loaded weights from model_14.pt
Epoch 0


100%|██████████| 389/389 [06:47<00:00,  1.05s/it]


Train Loss: 23.317258648516894, Train L2: 1.0810562372207642
Generating dataset with 25 scenes...


100%|██████████| 25/25 [07:13<00:00, 17.33s/it]


Loaded weights from model_15.pt
Epoch 0


100%|██████████| 389/389 [06:52<00:00,  1.06s/it]


Train Loss: 23.59692185703464, Train L2: 1.043887972831726
Generating dataset with 25 scenes...


100%|██████████| 25/25 [07:13<00:00, 17.33s/it]


Loaded weights from model_16.pt
Epoch 0


100%|██████████| 389/389 [06:49<00:00,  1.05s/it]


Train Loss: 22.92708967645248, Train L2: 1.054594874382019
Adjusting learning rate of group 0 to 3.4300e-04.
Generating dataset with 25 scenes...


100%|██████████| 25/25 [07:06<00:00, 17.07s/it]


Loaded weights from model_17.pt
Epoch 0


100%|██████████| 389/389 [06:52<00:00,  1.06s/it]


Train Loss: 24.68969833452475, Train L2: 1.1208006143569946
Generating dataset with 25 scenes...


100%|██████████| 25/25 [07:11<00:00, 17.28s/it]


Loaded weights from model_18.pt
Epoch 0


100%|██████████| 389/389 [06:53<00:00,  1.06s/it]


Train Loss: 24.26265456192291, Train L2: 1.1171727180480957
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:59<00:00, 14.37s/it]


Loaded weights from model_19.pt
Epoch 0


100%|██████████| 389/389 [06:45<00:00,  1.04s/it]


Train Loss: 23.439054332233027, Train L2: 1.0519728660583496
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:59<00:00, 14.38s/it]


Loaded weights from model_20.pt
Epoch 0


100%|██████████| 389/389 [06:37<00:00,  1.02s/it]


Train Loss: 21.767386301624132, Train L2: 0.9704540967941284
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:05<00:00, 14.60s/it]


Loaded weights from model_21.pt
Epoch 0


100%|██████████| 389/389 [06:51<00:00,  1.06s/it]


Train Loss: 23.38499410477273, Train L2: 1.0275403261184692
Adjusting learning rate of group 0 to 2.4010e-04.
Generating dataset with 25 scenes...


100%|██████████| 25/25 [07:08<00:00, 17.16s/it]


Loaded weights from model_22.pt
Epoch 0


100%|██████████| 389/389 [06:53<00:00,  1.06s/it]


Train Loss: 22.97784591027581, Train L2: 1.056542992591858
Generating dataset with 25 scenes...


100%|██████████| 25/25 [07:05<00:00, 17.03s/it]


Loaded weights from model_23.pt
Epoch 0


100%|██████████| 389/389 [06:46<00:00,  1.04s/it]


Train Loss: 20.41898385908426, Train L2: 0.9251740574836731
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:00<00:00, 14.43s/it]


Loaded weights from model_24.pt
Epoch 0


100%|██████████| 389/389 [06:57<00:00,  1.07s/it]


Train Loss: 23.825658352025677, Train L2: 1.0740584135055542
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:00<00:00, 14.42s/it]


Loaded weights from model_25.pt
Epoch 0


100%|██████████| 389/389 [06:46<00:00,  1.04s/it]


Train Loss: 22.642220416228383, Train L2: 0.9703188538551331
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:02<00:00, 14.49s/it]


Loaded weights from model_26.pt
Epoch 0


100%|██████████| 389/389 [06:56<00:00,  1.07s/it]


Train Loss: 22.397291006956124, Train L2: 0.9734132289886475
Adjusting learning rate of group 0 to 1.6807e-04.
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:06<00:00, 14.66s/it]


Loaded weights from model_27.pt
Epoch 0


100%|██████████| 389/389 [06:53<00:00,  1.06s/it]


Train Loss: 24.645934985045603, Train L2: 1.1168736219406128
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:56<00:00, 14.27s/it]


Loaded weights from model_28.pt
Epoch 0


100%|██████████| 389/389 [06:43<00:00,  1.04s/it]


Train Loss: 22.025538584260524, Train L2: 0.9367703199386597
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:02<00:00, 14.49s/it]


Loaded weights from model_29.pt
Epoch 0


100%|██████████| 389/389 [06:51<00:00,  1.06s/it]


Train Loss: 21.168042151051498, Train L2: 0.8852216005325317
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:54<00:00, 14.18s/it]


Loaded weights from model_30.pt
Epoch 0


100%|██████████| 389/389 [06:49<00:00,  1.05s/it]


Train Loss: 21.615131689527654, Train L2: 0.9396917819976807
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:56<00:00, 14.26s/it]


Loaded weights from model_31.pt
Epoch 0


100%|██████████| 389/389 [06:58<00:00,  1.07s/it]


Train Loss: 25.080363045621347, Train L2: 1.1189464330673218
Adjusting learning rate of group 0 to 1.1765e-04.
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:56<00:00, 14.28s/it]


Loaded weights from model_32.pt
Epoch 0


100%|██████████| 389/389 [06:48<00:00,  1.05s/it]


Train Loss: 24.169333730074925, Train L2: 1.003359317779541
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:56<00:00, 14.27s/it]


Loaded weights from model_33.pt
Epoch 0


100%|██████████| 389/389 [06:44<00:00,  1.04s/it]


Train Loss: 21.809192667277124, Train L2: 0.926364541053772
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:58<00:00, 14.33s/it]


Loaded weights from model_34.pt
Epoch 0


100%|██████████| 389/389 [06:57<00:00,  1.07s/it]


Train Loss: 22.94473246008081, Train L2: 0.9450706839561462
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:58<00:00, 14.32s/it]


Loaded weights from model_35.pt
Epoch 0


100%|██████████| 389/389 [06:50<00:00,  1.06s/it]


Train Loss: 22.35299638488299, Train L2: 0.9176115393638611
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:59<00:00, 14.40s/it]


Loaded weights from model_36.pt
Epoch 0


100%|██████████| 389/389 [06:54<00:00,  1.06s/it]


Train Loss: 21.514200313538694, Train L2: 0.8935184478759766
Adjusting learning rate of group 0 to 8.2354e-05.
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:59<00:00, 14.37s/it]


Loaded weights from model_37.pt
Epoch 0


100%|██████████| 389/389 [06:56<00:00,  1.07s/it]


Train Loss: 23.88959942991752, Train L2: 1.0756733417510986
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:54<00:00, 14.19s/it]


Loaded weights from model_38.pt
Epoch 0


100%|██████████| 389/389 [07:01<00:00,  1.08s/it]


Train Loss: 24.21579461968037, Train L2: 1.032889485359192
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:56<00:00, 14.25s/it]


Loaded weights from model_39.pt
Epoch 0


100%|██████████| 389/389 [06:53<00:00,  1.06s/it]


Train Loss: 22.806940696539794, Train L2: 0.9270835518836975
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:01<00:00, 14.47s/it]


Loaded weights from model_40.pt
Epoch 0


100%|██████████| 389/389 [07:00<00:00,  1.08s/it]


Train Loss: 23.440120299861487, Train L2: 1.0079116821289062
Generating dataset with 25 scenes...


100%|██████████| 25/25 [06:01<00:00, 14.44s/it]


Loaded weights from model_41.pt
Epoch 0


100%|██████████| 389/389 [06:50<00:00,  1.06s/it]


Train Loss: 21.769327202921968, Train L2: 0.9063311815261841
Adjusting learning rate of group 0 to 5.7648e-05.
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:59<00:00, 14.37s/it]


Loaded weights from model_42.pt
Epoch 0


100%|██████████| 389/389 [06:43<00:00,  1.04s/it]


Train Loss: 19.98553610460924, Train L2: 0.8556334972381592
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:55<00:00, 14.23s/it]


Loaded weights from model_43.pt
Epoch 0


100%|██████████| 389/389 [06:45<00:00,  1.04s/it]


Train Loss: 21.520572020033025, Train L2: 0.904503345489502
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:57<00:00, 14.31s/it]


Loaded weights from model_44.pt
Epoch 0


100%|██████████| 389/389 [06:57<00:00,  1.07s/it]


Train Loss: 24.8705400795434, Train L2: 1.0444378852844238
Generating dataset with 25 scenes...


100%|██████████| 25/25 [05:55<00:00, 14.21s/it]


Loaded weights from model_45.pt
Epoch 0


100%|██████████| 389/389 [06:45<00:00,  1.04s/it]


Train Loss: 19.359581989003942, Train L2: 0.742699146270752


In [9]:
from kwisatzHaderach_bh import KwisatzHaderachBH

In [10]:
# NOTE: this whole cell is a triple-quoted string, so none of it runs. It keeps
# the (disabled) training loop for the KwisatzHaderachBH model for reference:
# it resumes numbering from ./modelsbh/, empties ./train, generates a dataset
# with generate_dataset_memory_bh and trains on CPU via
# train_memory_black_hole_info. Because the string is the cell's last
# expression, Jupyter echoes it back as the cell output.
'''
model_files = os.listdir('./modelsbh/')
model_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
try:
    last_model_id = model_files[-1].split('_')[1].split('.')[0]
except IndexError:
    last_model_id = -1
last_model_id = int(last_model_id)

model = KwisatzHaderachBH(activation=True, layer_channels=[64, 64, 32, 3], calc_neighbors=False)



optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.7, verbose=True)

for i in range(40):
    # remove all files from train directory
    files = os.listdir('./train')
    for file in files:
        os.remove(os.path.join('./train', file))
    
    dataset = generate_dataset_memory_bh(2, window_size=2)

    
    model = train_memory_black_hole_info(model, dataset, './val', 16, loss_fn, optimizer, 1, './modelsbh', device='cpu', eval=False, use_custom_loss=False)
    del dataset
    last_model_id += 1
    torch.save(model.state_dict(), f'./modelsbh/model_{last_model_id}.pt')

    if i % 5 == 0:
        scheduler.step()
       ''' 

"\nmodel_files = os.listdir('./modelsbh/')\nmodel_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))\ntry:\n    last_model_id = model_files[-1].split('_')[1].split('.')[0]\nexcept IndexError:\n    last_model_id = -1\nlast_model_id = int(last_model_id)\n\nmodel = KwisatzHaderachBH(activation=True, layer_channels=[64, 64, 32, 3], calc_neighbors=False)\n\n\n\noptimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\nscheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.7, verbose=True)\n\nfor i in range(40):\n    # remove all files from train directory\n    files = os.listdir('./train')\n    for file in files:\n        os.remove(os.path.join('./train', file))\n    \n    dataset = generate_dataset_memory_bh(2, window_size=2)\n\n    \n    model = train_memory_black_hole_info(model, dataset, './val', 16, loss_fn, optimizer, 1, './modelsbh', device='cpu', eval=False, use_custom_loss=False)\n    del dataset\n    last_model_id += 1\n    torch.save(model.state_dict(), f'./mo