<h1>NN Model</h1>

In [1]:
from pycocotools.coco import COCO
import matplotlib
import matplotlib.pyplot as plt
import os
import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as functions
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
#import torchvision
from torchvision import transforms
import re

In [2]:
print("cuda" if torch.cuda.is_available() else "cpu")

cuda


In [3]:
DATADIR = "cocodoom/"
USED_RUNS = ["run1", "run2", "run3"]

dataSplit, TRAIN_RUN = "run-full-train", "run1"

annFile = '{}{}.json'.format(DATADIR,dataSplit)

In [4]:
coco_train = COCO(annFile)

loading annotations into memory...
Done (t=22.88s)
creating index...
index created!


Done (t=22.94s)
creating index...


index created!


In [5]:
dataSplit, VAL_RUN = "run-full-val", "run2"

annFile = '{}{}.json'.format(DATADIR,dataSplit)

In [6]:
coco_val = COCO(annFile)

loading annotations into memory...
Done (t=20.82s)
creating index...
index created!


Done (t=20.81s)
creating index...


index created!


In [7]:
dataSplit, TEST_RUN = "run-full-test", "run3"

annFile = '{}{}.json'.format(DATADIR,dataSplit)

In [8]:
coco_test = COCO(annFile)

loading annotations into memory...
Done (t=11.72s)
creating index...
index created!


Done (t=11.72s)
creating index...


index created!


In [9]:
player_positions = {"run1":[], "run2":[], "run3":[]}
motion_vectors = {"run1":[], "run2":[], "run3":[]}

for run in USED_RUNS:
    with open(DATADIR+run+"/log.txt", 'r') as log_file:
        for line in log_file:
            if "player" in line:
                line = line.strip()
                tic, stats = line.split("player:")
                x, y, z, angle = stats.split(",")
    
                # Store position in the dictionary
                player_positions[run].append((float(x), float(y), float(z), float(angle)))
                if len(player_positions[run]) >= 2:
                    player_position = player_positions[run][-1]
                    prev_player_position = player_positions[run][-2]
                    
                    dx = player_position[0] - prev_player_position[0]
                    dy = player_position[1] - prev_player_position[1]
                    dz = player_position[2] - prev_player_position[2]
                    dangle = np.pi - abs(abs(player_position[3] - prev_player_position[3]) - np.pi)
                    
                    dx_relative = dx * np.cos(2 * np.pi - prev_player_position[3]) + dy * np.cos(prev_player_position[3] - 1/2 * np.pi)
                    dy_relative = dx * np.sin(2 * np.pi - prev_player_position[3]) + dy * np.sin(prev_player_position[3] - 1/2 * np.pi)
                    motion_vector = (dx_relative, dy_relative, dz, dangle)
                    motion_vectors[run].append(motion_vector)

In [44]:
class DoomMotionDataset(Dataset):
    def __init__(self, coco, run, input_window, prediction_window, transform=None):
        self.coco = coco
        self.run = run
        self.img_ids = self.coco.getImgIds()
        self.transform = transform
        self.input_window = input_window
        self.prediction_window = prediction_window

    def __len__(self):
        return len(self.img_ids)

    def fullSegmentationFormat(self, rgb_filename):
        seg_image = self.load_image(self.getSegmentationMask(DATADIR + rgb_filename))
        if seg_image == None:
            return seg_image
        seg_class_map = self.color_to_index(seg_image)
        seg_class_one_hot = functions.one_hot(seg_class_map, num_classes=4).to(dtype=torch.float).permute(2, 0, 1)
        return seg_class_one_hot

    def fullDepthFormat(self, rgb_filename):
        depth_mask = self.load_image(self.getDepthMask(DATADIR + rgb_filename))
        if depth_mask == None:
            return depth_mask
        depth_mask = torch.tensor(depth_mask, dtype=torch.float32)
        return depth_mask

    def getSegmentationMask(self, rgb_filename):
        return rgb_filename.replace("rgb", "objects")

    def getDepthMask(self, rgb_filename):
        return rgb_filename.replace("rgb", "depth")

    def color_to_index(self, segmentation_image):
        # Map colors to class indices
        r, g, b = segmentation_image
        pixel_values = r + (g *  2**8) + (b * 2**16)  # From cocodoom documentation, converts to an object id

        class_map = torch.full_like(pixel_values, 3, dtype=torch.long)

        sky = (1 << 23) + 0
        horizontal = (1 << 23) + 1
        vertical = (1 << 23) + 2
        
        class_map[x == sky] = 0
        class_map[x == horizontal] = 1
        class_map[x == vertical] = 2
        return class_map

    def load_image(self, path):
        if os.path.exists(path):
            img = Image.open(path)
            return transforms.ToTensor()(img)
        return None

    def __getitem__(self, idx):
        # Load the RGB image
        rgb_filename = self.coco.loadImgs(self.img_ids[idx])[0]['file_name']
        print(rgb_filename)
        tic = int(rgb_filename.replace(".png", "").split("/")[-1])
        next_tic = tic+1
        previous_tic = tic-1
        prev_motion_vectors = []
        next_motion_vectors = []
        prev_seg = []
        prev_dep = []

        for t in range(input_window, 0, -1):
            if tic-t < 0:
                prev_motion_vectors.append(motion_vectors[self.run][0])
                prev_filename = self.coco.loadImgs(self.img_ids[0])[0]['file_name']
                seg = self.fullSegmentationFormat(prev_filename)
                dep = self.fullDepthFormat(prev_filename)
                prev_seg.append(seg)
                prev_dep.append(dep)
                continue
            elif tic-t >= len(motion_vectors[self.run]):
                prev_motion_vectors.append(motion_vectors[self.run][-1])
                prev_filename = self.coco.loadImgs(self.img_ids[-1])[0]['file_name']
                seg = self.fullSegmentationFormat(prev_filename)
                dep = self.fullDepthFormat(prev_filename)
                prev_seg.append(seg)
                prev_dep.append(dep)
                continue
            prev_motion_vectors.append(motion_vectors[self.run][tic-t])
            prev_filename = rgb_filename[:-10] + str(max(tic - t, 2)).rjust(6, "0") + ".png"
            # run1/map01/rgb/000002.png
            if os.path.exists(DATADIR + prev_filename):
                seg = self.fullSegmentationFormat(prev_filename)
                #print(f"seg shape: {seg.shape}")
                dep = self.fullDepthFormat(prev_filename)
                #print(f"dep shape: {dep.shape}")
                prev_seg.append(seg)
                prev_dep.append(dep)
            else:
                prev_seg.append(torch.zeros((4, 200, 320)))
                prev_dep.append(torch.zeros((1, 200, 320)))
                

        for t in range(1, prediction_window+1):
            if tic+t >= len(motion_vectors[self.run]):
                next_motion_vectors.append(motion_vectors[self.run][-1])
                continue
            next_motion_vectors.append(motion_vectors[self.run][tic+t])

        # if dx > 1000:
        #     print(f"idx: {idx}")
        #     print(f"rgb_filename: {rgb_filename}")
        #     print(f"tic: {tic}")
        #     print(f"next_tic: {next_tic}")
        #     print(f"previous_tic: {previous_tic}")
        #     print(f"Sus {idx}")
        #     print(f"prev_player_position: {prev_player_position}")
        #     print(f"player_position: {player_position}")
        #     print(f"next_player_position: {next_player_position}")
        #     print(f"prev_motion_vector: {prev_motion_vector}")
        #     print(f"next_motion_vector: {next_motion_vector}")

        #print(prev_motion_vectors)
        #print(next_motion_vectors)
            
        prev_motion_vectors = torch.tensor(prev_motion_vectors, dtype=torch.float32)
        next_motion_vectors = torch.tensor(next_motion_vectors, dtype=torch.float32)
        #print(len(prev_seg))
        prev_seg = torch.stack(prev_seg)
        prev_dep = torch.stack(prev_dep)
        
        return {"prev_motion" : prev_motion_vectors, "next_motion" : next_motion_vectors, "previous_seg" : prev_seg, "previous_dep" : prev_dep}


In [45]:
class NeuralNetwork(nn.Module):
  def __init__(self, batch_size, input_length, sequence_length, activation_function=functions.relu, device=torch.device("cpu")):
    super(NeuralNetwork, self).__init__()
    self.batch_size = batch_size
    self.input_length = input_length
    self.sequence_length = sequence_length

    # Encoder
    # Conv layers
    self.conv_seg = nn.Conv2d(4, 1, kernel_size=3, stride=2, padding=1, bias=False).to(device)
    self.conv_dep = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1, bias=False).to(device)

    self.motion_fc = nn.Linear(4, 32).to(device)
      
    # Pre-fusion LSTMs
    self.vis_LSTM = nn.LSTM(input_size=32000, hidden_size=256, batch_first=True).to(device)
    self.inertia_LSTM = nn.LSTM(input_size=32, hidden_size=256, batch_first=True).to(device)

    # Fusion LSTM
    self.fusion_LSTM = nn.LSTM(input_size=512, hidden_size=256, batch_first=True).to(device)

    # Decoder
    self.de_motion_fc = nn.Linear(4, 32).to(device)
    self.de_vis_LSTM = nn.LSTM(input_size=32, hidden_size=256, batch_first=True).to(device) #Unsure what the input size of this should be as it actually receives nothing
    self.de_inertia_LSTM = nn.LSTM(input_size=32, hidden_size=256, batch_first=True).to(device)
    self.de_fusion_LSTM = nn.LSTM(input_size=512, hidden_size=256, batch_first=True).to(device)
    self.output_fc = nn.Linear(256, 4).to(device)

  def forward(self, segmentation, depth, prev_motion):
    hidden_vis = None
    hidden_inert = None
    hidden_fus = None
    
    for t in range(self.input_length):
        #print(segmentation.shape)
        seg = self.conv_seg(segmentation[:,t])
        #print(seg.shape)
        dep = self.conv_dep(depth[:,t])
        #print(dep.shape)
        mot = self.motion_fc(prev_motion[:,t])
        vis = torch.cat((seg, dep), dim=1)
        vis = torch.flatten(vis, start_dim=1)
        #print(vis.shape)
        if hidden_vis != None:
            output_vis, hidden_vis = self.vis_LSTM(vis, hidden_vis)
        else:
            output_vis, hidden_vis = self.vis_LSTM(vis)
        if hidden_inert != None:
            output_inert, hidden_inert = self.inertia_LSTM(mot, hidden_inert)
        else:
            output_inert, hidden_inert = self.inertia_LSTM(mot)
        combined = torch.cat((output_vis, output_inert), dim=1)
        if hidden_fus != None:
            _, hidden_fus = self.fusion_LSTM(combined, hidden_fus)
        else:
            _, hidden_fus = self.fusion_LSTM(combined)

    #print("Prev motion: " + str(prev_motion.shape))
    de_mot = prev_motion[:,-1]
    output_tensor = torch.zeros(self.sequence_length, self.batch_size, 4).to(segmentation.device)
    for t in range(self.sequence_length):
        #print(de_mot.shape)
        de_mot = self.de_motion_fc(de_mot)
        de_output_inert, hidden_inert = self.de_inertia_LSTM(de_mot, hidden_inert)
        de_output_vis, hidden_vis = self.de_vis_LSTM(torch.zeros(self.batch_size, 32).to(segmentation.device), hidden_vis)
        #print(de_output_vis.shape, de_output_inert.shape)
        combined = torch.cat((de_output_vis, de_output_inert), dim=1)
        de_output_fus, hidden_fus = self.de_fusion_LSTM(combined, hidden_fus)
        #print("de_output_fus: " + str(de_output_fus.shape))
        output_t = self.output_fc(de_output_fus)
        #print("output_t: " + str(output_t.shape))
        #output_t = output_t.unsqueeze(0)
        de_mot = output_t
        output_tensor[t] = output_t.unsqueeze(0)
        
    return output_tensor

In [46]:
torch.cuda.empty_cache()
batch_size = 256
learning_rate = 1e-3
num_epochs = 10
input_window = 5
prediction_window = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("cuda" if torch.cuda.is_available() else "cpu")
model = NeuralNetwork(batch_size, input_window, prediction_window, device=device).to(device)

train_dataset = DoomMotionDataset(coco_train, TRAIN_RUN, input_window, prediction_window)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

val_dataset = DoomMotionDataset(coco_val, VAL_RUN, input_window, prediction_window)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Loss function and optimizer
criterion = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
    for batch_idx, batch in enumerate(progress_bar):
        prev_motion, next_motion, previous_seg, previous_dep = batch["prev_motion"], batch["next_motion"], batch["previous_seg"], batch["previous_dep"]
        prev_motion, next_motion, previous_seg, previous_dep = prev_motion.to(device), next_motion.to(device), previous_seg.to(device), previous_dep.to(device)

        if prev_motion.shape[0] != next_motion.shape[0]:
                continue

        optimizer.zero_grad()

        outputs = model(previous_seg, previous_dep, prev_motion)
        outputs = outputs.permute(1, 0, 2)

        if outputs.size(0) != next_motion.size(0):
            continue
        
        loss = criterion(outputs, next_motion)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()

        progress_bar.set_postfix({
            "batch_loss": loss.item(),
            "batch_index": batch_idx + 1,
            "batch_size": prev_motion.size(0)
        })

    # Average loss per epoch
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

    model.eval()  # Set the model to evaluation mode
    running_loss = 0.0
    
    
    progress_bar = tqdm(val_loader, desc="Validation", unit="batch")
    
    with torch.no_grad():  # Disable gradient calculations for evaluation
        for batch_idx, batch in enumerate(progress_bar):
            prev_motion, next_motion, previous_seg, previous_dep = batch["prev_motion"], batch["next_motion"], batch["previous_seg"], batch["previous_dep"]
            prev_motion, next_motion, previous_seg, previous_dep = prev_motion.to(device), next_motion.to(device), previous_seg.to(device), previous_dep.to(device)

            if prev_motion.shape[0] != next_motion.shape[0]:
                continue
                
            outputs = model(previous_seg, previous_dep, prev_motion)
            outputs = outputs.permute(1, 0, 2)

            if outputs.size(0) != next_motion.size(0):
                continue
            
            loss = criterion(outputs, next_motion)
            
            running_loss += loss.item()
            
            progress_bar.set_postfix({
                "batch_loss": loss.item(),
                "batch_index": batch_idx + 1,
                "batch_size": prev_motion.size(0)
            })
    
    # Average loss over all batches
    val_loss = running_loss / len(val_loader)
    print(f"Val Loss: {val_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), "multimodal_seq2seq.pth")

cuda


  depth_mask = torch.tensor(depth_mask, dtype=torch.float32)


run1/map01/rgb/000002.png
run1/map01/rgb/000003.png
run1/map01/rgb/000004.png
run1/map01/rgb/000005.png
run1/map01/rgb/000006.png
run1/map01/rgb/000007.png
run1/map01/rgb/000008.png
run1/map01/rgb/000009.png
run1/map01/rgb/000010.png
run1/map01/rgb/000011.png
run1/map01/rgb/000012.png
run1/map01/rgb/000013.png
run1/map01/rgb/000014.png
run1/map01/rgb/000015.png
run1/map01/rgb/000016.png
run1/map01/rgb/000017.png
run1/map01/rgb/000018.png
run1/map01/rgb/000019.png
run1/map01/rgb/000020.png
run1/map01/rgb/000021.png
run1/map01/rgb/000022.png
run1/map01/rgb/000023.png
run1/map01/rgb/000024.png
run1/map01/rgb/000025.png
run1/map01/rgb/000026.png
run1/map01/rgb/000027.png
run1/map01/rgb/000028.png
run1/map01/rgb/000029.png
run1/map01/rgb/000030.png
run1/map01/rgb/000031.png
run1/map01/rgb/000032.png
run1/map01/rgb/000033.png
run1/map01/rgb/000034.png
run1/map01/rgb/000035.png
run1/map01/rgb/000036.png
run1/map01/rgb/000037.png
run1/map01/rgb/000038.png
run1/map01/rgb/000039.png
run1/map01/r

Epoch 1/10:   0%|                  | 1/991 [00:13<3:47:34, 13.79s/batch, batch_loss=32.6, batch_index=1, batch_size=256]

run1/map01/rgb/000258.png
run1/map01/rgb/000259.png
run1/map01/rgb/000260.png
run1/map01/rgb/000261.png
run1/map01/rgb/000262.png
run1/map01/rgb/000263.png
run1/map01/rgb/000264.png
run1/map01/rgb/000265.png
run1/map01/rgb/000266.png
run1/map01/rgb/000267.png
run1/map01/rgb/000268.png
run1/map01/rgb/000269.png
run1/map01/rgb/000270.png
run1/map01/rgb/000271.png
run1/map01/rgb/000272.png
run1/map01/rgb/000273.png
run1/map01/rgb/000274.png
run1/map01/rgb/000275.png
run1/map01/rgb/000276.png
run1/map01/rgb/000277.png
run1/map01/rgb/000278.png
run1/map01/rgb/000279.png
run1/map01/rgb/000280.png
run1/map01/rgb/000281.png
run1/map01/rgb/000282.png
run1/map01/rgb/000283.png
run1/map01/rgb/000284.png
run1/map01/rgb/000285.png
run1/map01/rgb/000286.png
run1/map01/rgb/000287.png
run1/map01/rgb/000288.png
run1/map01/rgb/000289.png
run1/map01/rgb/000290.png
run1/map01/rgb/000291.png
run1/map01/rgb/000292.png
run1/map01/rgb/000293.png
run1/map01/rgb/000294.png
run1/map01/rgb/000295.png
run1/map01/r

Epoch 1/10:   0%|                  | 2/991 [00:27<3:47:42, 13.81s/batch, batch_loss=32.5, batch_index=2, batch_size=256]

run1/map01/rgb/000514.png
run1/map01/rgb/000515.png
run1/map01/rgb/000516.png
run1/map01/rgb/000517.png
run1/map01/rgb/000518.png
run1/map01/rgb/000519.png
run1/map01/rgb/000520.png
run1/map01/rgb/000521.png
run1/map01/rgb/000522.png
run1/map01/rgb/000523.png
run1/map01/rgb/000524.png
run1/map01/rgb/000525.png
run1/map01/rgb/000526.png
run1/map01/rgb/000527.png
run1/map01/rgb/000528.png
run1/map01/rgb/000529.png
run1/map01/rgb/000530.png
run1/map01/rgb/000531.png
run1/map01/rgb/000532.png
run1/map01/rgb/000533.png
run1/map01/rgb/000534.png
run1/map01/rgb/000535.png
run1/map01/rgb/000536.png
run1/map01/rgb/000537.png
run1/map01/rgb/000538.png
run1/map01/rgb/000539.png
run1/map01/rgb/000540.png
run1/map01/rgb/000541.png
run1/map01/rgb/000542.png
run1/map01/rgb/000543.png
run1/map01/rgb/000544.png
run1/map01/rgb/000545.png
run1/map01/rgb/000546.png
run1/map01/rgb/000547.png
run1/map01/rgb/000548.png
run1/map01/rgb/000549.png
run1/map01/rgb/000550.png
run1/map01/rgb/000551.png
run1/map01/r

Epoch 1/10:   0%|                    | 3/991 [00:41<3:48:47, 13.89s/batch, batch_loss=27, batch_index=3, batch_size=256]

run1/map01/rgb/000770.png
run1/map01/rgb/000771.png
run1/map01/rgb/000772.png
run1/map01/rgb/000773.png
run1/map01/rgb/000774.png
run1/map01/rgb/000775.png
run1/map01/rgb/000776.png
run1/map01/rgb/000777.png
run1/map01/rgb/000778.png
run1/map01/rgb/000779.png
run1/map01/rgb/000780.png
run1/map01/rgb/000781.png
run1/map01/rgb/000782.png
run1/map01/rgb/000783.png
run1/map01/rgb/000784.png
run1/map01/rgb/000785.png
run1/map01/rgb/000786.png
run1/map01/rgb/000787.png
run1/map01/rgb/000788.png
run1/map01/rgb/000789.png
run1/map01/rgb/000790.png
run1/map01/rgb/000791.png
run1/map01/rgb/000792.png
run1/map01/rgb/000793.png
run1/map01/rgb/000794.png
run1/map01/rgb/000795.png
run1/map01/rgb/000796.png
run1/map01/rgb/000797.png
run1/map01/rgb/000798.png
run1/map01/rgb/000799.png
run1/map01/rgb/000800.png
run1/map01/rgb/000801.png
run1/map01/rgb/000802.png
run1/map01/rgb/000803.png
run1/map01/rgb/000804.png
run1/map01/rgb/000805.png
run1/map01/rgb/000806.png
run1/map01/rgb/000807.png
run1/map01/r

Epoch 1/10:   0%|                  | 4/991 [00:55<3:47:31, 13.83s/batch, batch_loss=12.4, batch_index=4, batch_size=256]

run1/map01/rgb/001026.png
run1/map01/rgb/001027.png
run1/map01/rgb/001028.png
run1/map01/rgb/001029.png
run1/map01/rgb/001030.png
run1/map01/rgb/001031.png
run1/map01/rgb/001032.png
run1/map01/rgb/001033.png
run1/map01/rgb/001034.png
run1/map01/rgb/001035.png
run1/map01/rgb/001036.png
run1/map01/rgb/001037.png
run1/map01/rgb/001038.png
run1/map01/rgb/001039.png
run1/map01/rgb/001040.png
run1/map01/rgb/001041.png
run1/map01/rgb/001042.png
run1/map01/rgb/001043.png
run1/map01/rgb/001044.png
run1/map01/rgb/001045.png
run1/map01/rgb/001046.png
run1/map01/rgb/001047.png
run1/map01/rgb/001048.png
run1/map01/rgb/001049.png
run1/map01/rgb/001050.png
run1/map01/rgb/001051.png
run1/map01/rgb/001052.png
run1/map01/rgb/001053.png
run1/map01/rgb/001054.png
run1/map01/rgb/001055.png
run1/map01/rgb/001056.png
run1/map01/rgb/001057.png
run1/map01/rgb/001058.png
run1/map01/rgb/001059.png
run1/map01/rgb/001060.png
run1/map01/rgb/001061.png
run1/map01/rgb/001062.png
run1/map01/rgb/001063.png
run1/map01/r

Epoch 1/10:   1%|                  | 5/991 [01:08<3:44:39, 13.67s/batch, batch_loss=41.4, batch_index=5, batch_size=256]

run1/map01/rgb/001282.png
run1/map01/rgb/001283.png
run1/map01/rgb/001284.png
run1/map01/rgb/001285.png
run1/map01/rgb/001286.png
run1/map01/rgb/001287.png
run1/map01/rgb/001288.png
run1/map01/rgb/001289.png
run1/map01/rgb/001290.png
run1/map01/rgb/001291.png
run1/map01/rgb/001292.png
run1/map01/rgb/001293.png
run1/map01/rgb/001294.png
run1/map01/rgb/001295.png
run1/map01/rgb/001296.png
run1/map01/rgb/001297.png
run1/map01/rgb/001298.png
run1/map01/rgb/001299.png
run1/map01/rgb/001300.png
run1/map01/rgb/001301.png
run1/map01/rgb/001302.png
run1/map01/rgb/001303.png
run1/map01/rgb/001304.png
run1/map01/rgb/001305.png
run1/map01/rgb/001306.png
run1/map01/rgb/001307.png
run1/map01/rgb/001308.png
run1/map01/rgb/001309.png
run1/map01/rgb/001310.png
run1/map01/rgb/001311.png
run1/map01/rgb/001312.png
run1/map01/rgb/001313.png
run1/map01/rgb/001314.png
run1/map01/rgb/001315.png
run1/map01/rgb/001316.png
run1/map01/rgb/001317.png
run1/map01/rgb/001318.png
run1/map01/rgb/001319.png
run1/map01/r

Epoch 1/10:   1%|                  | 6/991 [01:22<3:42:36, 13.56s/batch, batch_loss=34.9, batch_index=6, batch_size=256]

run1/map01/rgb/001538.png
run1/map01/rgb/001539.png
run1/map01/rgb/001540.png
run1/map01/rgb/001541.png
run1/map01/rgb/001542.png
run1/map01/rgb/001543.png
run1/map01/rgb/001544.png
run1/map01/rgb/001545.png
run1/map01/rgb/001546.png
run1/map01/rgb/001547.png
run1/map01/rgb/001548.png
run1/map01/rgb/001549.png
run1/map01/rgb/001550.png
run1/map01/rgb/001551.png
run1/map01/rgb/001552.png
run1/map01/rgb/001553.png
run1/map01/rgb/001554.png
run1/map01/rgb/001555.png
run1/map01/rgb/001556.png
run1/map01/rgb/001557.png
run1/map01/rgb/001558.png
run1/map01/rgb/001559.png
run1/map01/rgb/001560.png
run1/map01/rgb/001561.png
run1/map01/rgb/001562.png
run1/map01/rgb/001563.png
run1/map01/rgb/001564.png
run1/map01/rgb/001565.png
run1/map01/rgb/001566.png
run1/map01/rgb/001567.png
run1/map01/rgb/001568.png
run1/map01/rgb/001569.png
run1/map01/rgb/001570.png
run1/map01/rgb/001571.png
run1/map01/rgb/001572.png
run1/map01/rgb/001573.png
run1/map01/rgb/001574.png
run1/map01/rgb/001575.png
run1/map01/r

Epoch 1/10:   1%|▏                 | 7/991 [01:35<3:43:08, 13.61s/batch, batch_loss=30.5, batch_index=7, batch_size=256]

run1/map01/rgb/001794.png
run1/map01/rgb/001795.png
run1/map01/rgb/001796.png
run1/map01/rgb/001797.png
run1/map01/rgb/001798.png
run1/map01/rgb/001799.png
run1/map01/rgb/001800.png
run1/map01/rgb/001801.png
run1/map01/rgb/001802.png
run1/map01/rgb/001803.png
run1/map01/rgb/001804.png
run1/map01/rgb/001805.png
run1/map01/rgb/001806.png
run1/map01/rgb/001807.png
run1/map01/rgb/001808.png
run1/map01/rgb/001809.png
run1/map01/rgb/001810.png
run1/map01/rgb/001811.png
run1/map01/rgb/001812.png
run1/map01/rgb/001813.png
run1/map01/rgb/001814.png
run1/map01/rgb/001815.png
run1/map01/rgb/001816.png
run1/map01/rgb/001817.png
run1/map01/rgb/001818.png
run1/map01/rgb/001819.png
run1/map01/rgb/001820.png
run1/map01/rgb/001821.png
run1/map01/rgb/001822.png
run1/map01/rgb/001823.png
run1/map01/rgb/001824.png
run1/map01/rgb/001825.png
run1/map01/rgb/001826.png
run1/map01/rgb/001827.png
run1/map01/rgb/001828.png
run1/map01/rgb/001829.png
run1/map01/rgb/001830.png
run1/map01/rgb/001831.png
run1/map01/r

Epoch 1/10:   1%|▏                  | 8/991 [01:56<4:22:28, 16.02s/batch, batch_loss=610, batch_index=8, batch_size=256]

run1/map02/rgb/002088.png
run1/map02/rgb/002089.png
run1/map02/rgb/002090.png
run1/map02/rgb/002091.png
run1/map02/rgb/002092.png
run1/map02/rgb/002093.png
run1/map02/rgb/002094.png
run1/map02/rgb/002095.png
run1/map02/rgb/002096.png
run1/map02/rgb/002097.png
run1/map02/rgb/002098.png
run1/map02/rgb/002099.png
run1/map02/rgb/002100.png
run1/map02/rgb/002101.png
run1/map02/rgb/002102.png
run1/map02/rgb/002103.png
run1/map02/rgb/002104.png
run1/map02/rgb/002105.png
run1/map02/rgb/002106.png
run1/map02/rgb/002107.png
run1/map02/rgb/002108.png
run1/map02/rgb/002109.png
run1/map02/rgb/002110.png
run1/map02/rgb/002111.png
run1/map02/rgb/002112.png
run1/map02/rgb/002113.png
run1/map02/rgb/002114.png
run1/map02/rgb/002115.png
run1/map02/rgb/002116.png
run1/map02/rgb/002117.png
run1/map02/rgb/002118.png
run1/map02/rgb/002119.png
run1/map02/rgb/002120.png
run1/map02/rgb/002121.png
run1/map02/rgb/002122.png
run1/map02/rgb/002123.png
run1/map02/rgb/002124.png
run1/map02/rgb/002125.png
run1/map02/r

Epoch 1/10:   1%|▏                 | 9/991 [02:15<4:32:33, 16.65s/batch, batch_loss=19.5, batch_index=9, batch_size=256]

run1/map02/rgb/002344.png
run1/map02/rgb/002345.png
run1/map02/rgb/002346.png
run1/map02/rgb/002347.png
run1/map02/rgb/002348.png
run1/map02/rgb/002349.png
run1/map02/rgb/002350.png
run1/map02/rgb/002351.png
run1/map02/rgb/002352.png
run1/map02/rgb/002353.png
run1/map02/rgb/002354.png
run1/map02/rgb/002355.png
run1/map02/rgb/002356.png
run1/map02/rgb/002357.png
run1/map02/rgb/002358.png
run1/map02/rgb/002359.png
run1/map02/rgb/002360.png
run1/map02/rgb/002361.png
run1/map02/rgb/002362.png
run1/map02/rgb/002363.png
run1/map02/rgb/002364.png
run1/map02/rgb/002365.png
run1/map02/rgb/002366.png
run1/map02/rgb/002367.png
run1/map02/rgb/002368.png
run1/map02/rgb/002369.png
run1/map02/rgb/002370.png
run1/map02/rgb/002371.png
run1/map02/rgb/002372.png
run1/map02/rgb/002373.png
run1/map02/rgb/002374.png
run1/map02/rgb/002375.png
run1/map02/rgb/002376.png
run1/map02/rgb/002377.png
run1/map02/rgb/002378.png
run1/map02/rgb/002379.png
run1/map02/rgb/002380.png
run1/map02/rgb/002381.png
run1/map02/r

Epoch 1/10:   1%|▏               | 10/991 [02:31<4:30:18, 16.53s/batch, batch_loss=21.6, batch_index=10, batch_size=256]

run1/map02/rgb/002600.png
run1/map02/rgb/002601.png
run1/map02/rgb/002602.png
run1/map02/rgb/002603.png
run1/map02/rgb/002604.png
run1/map02/rgb/002605.png
run1/map02/rgb/002606.png
run1/map02/rgb/002607.png
run1/map02/rgb/002608.png
run1/map02/rgb/002609.png
run1/map02/rgb/002610.png
run1/map02/rgb/002611.png
run1/map02/rgb/002612.png
run1/map02/rgb/002613.png
run1/map02/rgb/002614.png
run1/map02/rgb/002615.png
run1/map02/rgb/002616.png
run1/map02/rgb/002617.png
run1/map02/rgb/002618.png
run1/map02/rgb/002619.png
run1/map02/rgb/002620.png
run1/map02/rgb/002621.png
run1/map02/rgb/002622.png
run1/map02/rgb/002623.png
run1/map02/rgb/002624.png
run1/map02/rgb/002625.png
run1/map02/rgb/002626.png
run1/map02/rgb/002627.png
run1/map02/rgb/002628.png
run1/map02/rgb/002629.png
run1/map02/rgb/002630.png
run1/map02/rgb/002631.png
run1/map02/rgb/002632.png
run1/map02/rgb/002633.png
run1/map02/rgb/002634.png
run1/map02/rgb/002635.png
run1/map02/rgb/002636.png
run1/map02/rgb/002637.png
run1/map02/r

Epoch 1/10:   1%|▏               | 10/991 [02:43<4:26:40, 16.31s/batch, batch_loss=21.6, batch_index=10, batch_size=256]


KeyboardInterrupt: 

Epoch 1/10:   0%|                            | 0/991 [00:14<?, ?batch/s, batch_loss=32.6, batch_index=1, batch_size=256]

Epoch 1/10:   0%|                  | 1/991 [00:14<3:58:23, 14.45s/batch, batch_loss=32.6, batch_index=1, batch_size=256]

Epoch 1/10:   0%|                  | 1/991 [00:29<3:58:23, 14.45s/batch, batch_loss=32.5, batch_index=2, batch_size=256]

Epoch 1/10:   0%|                  | 2/991 [00:29<4:02:24, 14.71s/batch, batch_loss=32.5, batch_index=2, batch_size=256]

Epoch 1/10:   0%|                  | 2/991 [00:43<4:02:24, 14.71s/batch, batch_loss=26.8, batch_index=3, batch_size=256]

Epoch 1/10:   0%|                  | 3/991 [00:43<3:57:47, 14.44s/batch, batch_loss=26.8, batch_index=3, batch_size=256]

Epoch 1/10:   0%|                  | 3/991 [00:57<3:57:47, 14.44s/batch, batch_loss=12.5, batch_index=4, batch_size=256]

Epoch 1/10:   0%|                  | 4/991 [00:57<3:53:30, 14.19s/batch, batch_loss=12.5, batch_index=4, batch_size=256]

Epoch 1/10:   0%|                    | 4/991 [01:11<3:53:30, 14.19s/batch, batch_loss=42, batch_index=5, batch_size=256]

Epoch 1/10:   1%|                    | 5/991 [01:11<3:55:50, 14.35s/batch, batch_loss=42, batch_index=5, batch_size=256]

Epoch 1/10:   1%|                  | 5/991 [01:24<3:55:50, 14.35s/batch, batch_loss=34.9, batch_index=6, batch_size=256]

Epoch 1/10:   1%|                  | 6/991 [01:24<3:46:56, 13.82s/batch, batch_loss=34.9, batch_index=6, batch_size=256]

Epoch 1/10:   1%|                    | 6/991 [01:38<3:46:56, 13.82s/batch, batch_loss=31, batch_index=7, batch_size=256]

Epoch 1/10:   1%|▏                   | 7/991 [01:38<3:46:50, 13.83s/batch, batch_loss=31, batch_index=7, batch_size=256]

Epoch 1/10:   1%|▏                   | 7/991 [01:41<3:57:25, 14.48s/batch, batch_loss=31, batch_index=7, batch_size=256]




RuntimeError: stack expects a non-empty TensorList

In [None]:
device = (torch.device("cuda" if torch.cuda.is_available() else "cpu"))
criterion = torch.nn.MSELoss()
model = NeuralNetwork(batch_size, input_window, prediction_window, device=device).to(device)
model.load_state_dict(torch.load("multimodal_seq2seq.pth", weights_only=True))

test_dataset = DoomMotionDataset(coco_test, TEST_RUN, input_window, prediction_window)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

model.eval()  # Set the model to evaluation mode
running_loss = 0.0


progress_bar = tqdm(test_loader, desc="Testing", unit="batch")

with torch.no_grad():  # Disable gradient calculations for evaluation
    for batch_idx, batch in enumerate(progress_bar):
        prev_motion, next_motion, previous_seg, previous_dep = batch["prev_motion"], batch["next_motion"], batch["previous_seg"], batch["previous_dep"]
        prev_motion, next_motion, previous_seg, previous_dep = prev_motion.to(device), next_motion.to(device), previous_seg.to(device), previous_dep.to(device)

        if prev_motion.shape[0] != next_motion.shape[0]:
                continue
            
        outputs = model(previous_seg, previous_dep, prev_motion)
        outputs = outputs.permute(1, 0, 2)

        if outputs.size(0) != next_motion.size(0):
            continue
        
        loss = criterion(outputs, next_motion)
        
        running_loss += loss.item()
        
        progress_bar.set_postfix({
            "batch_loss": loss.item(),
            "batch_index": batch_idx + 1,
            "batch_size": prev_motion.size(0)
        })

# Average loss over all batches
test_loss = running_loss / len(test_loader)
print(f"Test Loss: {test_loss:.4f}")