## Notebook for comparing FilterNet performance against unfiltered residual masks

In [1]:
import os 
import subprocess 
import numpy as np 
import torch 
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader 
from torchvision import transforms, utils 
import pandas as pd 
import matplotlib.pyplot as plt 
from typing import List, Dict, Tuple 
from tqdm import tqdm 
from scipy.spatial.transform import Rotation as R 
from PIL import Image 
from skimage import io, transform

## Dataloader

In [2]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook
# still runs (slowly) on a machine without a usable GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Create dataset class
class VideoDataset(Dataset):
    """
    Sliding-window dataset over per-video mask frames and robot state.

    Every item is a window of `seq_length` consecutive frames taken from a
    single video, together with the pose/twist state of each frame in the
    window. Frames are read from disk lazily in __getitem__().
    """

    def __init__(self, path_to_data, augmented=False, num_vids=7, seq_length=5, img_size=256):
        """
        Go through each video folder and build a map from
        index i in range [0, N), where N is the total number
        of sequences in all the videos, to a tuple (j, k), where
        j is the video number and k the index of the first frame
        of the sequence in that video. For reference, the videos
        are stored in the dataset as:

        data:
        - test1:
          - inputs:
            - bd_poses.csv
            - bd_twists.csv
            - frame0000.png
            - frame0001.png
            - ...
          - targets:
            - frame0000.png
            - frame0001.png
            - ...
        - test2:
            - ...
        - ...

        Args:
            path_to_data: root folder containing the test<i> folders.
            augmented: if True, items carry the state broadcast over every
                pixel of the frame instead of a separate state tensor.
            num_vids: number of videos; folders test1 .. test<num_vids> are read.
            seq_length: number of consecutive frames per item.
            img_size: side length frames are resized to (square output).

        Raises:
            FileNotFoundError: if a pose or twist CSV file is missing.
        """
        self.img_size = img_size
        self.seq_length = seq_length
        self.augmented = augmented
        self.num_vids = num_vids
        self.path_to_data = path_to_data
        self.idx_map: List[Tuple[int, int]] = []
        self.poses: Dict[int, np.ndarray] = {}
        self.twists: Dict[int, np.ndarray] = {}
        for i in tqdm(range(1, self.num_vids + 1)):
            # Define path to pose and twist files
            pose_path = os.path.join(path_to_data, f'test{i}/inputs/bd_poses.csv')
            twist_path = os.path.join(path_to_data, f'test{i}/inputs/bd_twists.csv')

            # Fail early with a clear message if either file is missing
            if not os.path.isfile(pose_path):
                raise FileNotFoundError(f"Missing pose file: {pose_path}")
            if not os.path.isfile(twist_path):
                raise FileNotFoundError(f"Missing twist file: {twist_path}")

            # Load one row per frame
            self.poses[i] = pd.read_csv(pose_path).to_numpy()
            self.twists[i] = pd.read_csv(twist_path).to_numpy()

            # Number of window start positions in this video. This is one
            # fewer than num_rows - (seq_length - 1), i.e. the last possible
            # window is dropped.
            # NOTE(review): confirm the extra "- 1" is intentional.
            num_sequences = (self.poses[i].shape[0] - 1) - (self.seq_length - 1)

            # Update index map (a non-positive count simply adds nothing)
            video_num = [i] * num_sequences
            frame_idx = list(range(0, num_sequences))
            self.idx_map.extend(list(zip(video_num, frame_idx)))

        self.total_num_sequences = len(self.idx_map)

    def __len__(self):
        """
        Return length of dataset as computed in __init__() function.
        """
        return self.total_num_sequences

    def __getitem__(self, idx):
        """
        Use map built in __init__() to retrieve the images,
        poses, and twists of one sequence directly from disk.
        This avoids loading the entire dataset which
        overwhelms RAM.

        Returns:
            If not augmented, a dict with
              'input':   (frames of shape (T, 1, S, S), states of shape (T, D)),
              'target':  target frames of shape (T, 1, S, S),
              'indices': float array of shape (T, 2) holding (video, frame).
            If augmented, 'input' is a single tensor of shape
            (T, 1, S, S, 1 + D) whose last axis is the pixel value followed
            by the state. T = seq_length, S = img_size, and D is the 7 pose
            values followed by the twist row.

        Raises:
            IndexError: if idx is outside [-len(self), len(self)).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Raise IndexError (not AssertionError) so the check survives `-O`
        # and plain iteration over the dataset terminates correctly.
        if not -self.total_num_sequences <= idx < self.total_num_sequences:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.total_num_sequences}"
            )

        test_idx, first_frame_idx = self.idx_map[idx]

        # Build the resize transform once per item rather than once per frame
        h, w = self.img_size, self.img_size
        resize_frame = transforms.Resize((h, w))

        # Return indices for saving processed images
        indices_arr = np.zeros((self.seq_length, 2))

        inputs, targets, states = [], [], []
        for i in range(self.seq_length):
            # Define path to data
            frame_idx = first_frame_idx + i
            padded_frame_idx = self.to_zero_pad_idx(frame_idx)
            input_img_path = os.path.join(self.path_to_data,
                                    f'test{test_idx}/inputs/frame{padded_frame_idx}.png')
            output_img_path = os.path.join(self.path_to_data,
                                    f'test{test_idx}/targets/frame{padded_frame_idx}.png')

            # Store indices in indices array
            indices_arr[i] = np.array([test_idx, frame_idx])

            # Load and binarize the frames
            input_frame = io.imread(input_img_path)
            output_frame = io.imread(output_img_path)
            input_frame = torch.from_numpy(self.to_grayscale(input_frame)).float()
            output_frame = torch.from_numpy(self.to_grayscale(output_frame)).float()

            # Build the state vector: pose (position + quaternion) then twist
            pose = self.poses[test_idx][frame_idx]
            pose = torch.from_numpy(self.pose_vector_from_matrix(pose).reshape(-1, 1)).float()
            twist = torch.from_numpy(self.twists[test_idx][frame_idx].reshape(-1, 1)).float()
            state = torch.cat((pose, twist), dim=0).view(-1)

            # Resize frames; the added leading axis is the channel axis
            output_frame = resize_frame(output_frame.unsqueeze(0))
            input_frame = resize_frame(input_frame.unsqueeze(0))

            if not self.augmented:
                inputs.append(input_frame.view(1, h, w))
                states.append(state)
            else:
                # Broadcast the state over every pixel and append it to the
                # pixel value along a new trailing axis.
                expanded_frame = input_frame.view(1, h, w, 1)
                expanded_state = state.view(1, 1, 1, -1).repeat(1, h, w, 1)
                inputs.append(torch.cat((expanded_frame, expanded_state), dim=-1))
            targets.append(output_frame.view(1, h, w))

        target_seq = torch.stack(targets)
        if self.augmented:
            return {'input': torch.stack(inputs), 'target': target_seq, 'indices': indices_arr}
        return {'input': (torch.stack(inputs), torch.stack(states)),
                'target': target_seq, 'indices': indices_arr}

    def to_zero_pad_idx(self, idx):
        """
        Convert frame index from regular index to zero-padded index.
        e.g. 1 -> 00001, 12 -> 00012
        """
        return f'{idx:05d}'

    def pose_vector_from_matrix(self, pose):
        """
        Convert 4x4 pose matrix (as a flattened length 16 vector) into a
        length 7 vector: position (3) followed by a unit quaternion (4).
        """
        pose = pose.reshape(4, 4)
        position = pose[:3, 3].reshape(3, 1)
        orientation = pose[:3, :3]

        quat = R.from_matrix(orientation).as_quat().reshape(-1, 1)
        norm_quat = quat / np.linalg.norm(quat)

        return np.vstack((position, norm_quat)).reshape(-1)

    def to_grayscale(self, image):
        """
        Convert PNG image to a binary float32 mask (1.0 where the value
        exceeds 127, else 0.0).

        Multi-channel images are thresholded on their first channel; images
        that are already 2-D (single channel) are thresholded directly.
        """
        #TODO: Try float16 type
        channel = image if image.ndim == 2 else image[..., 0]
        return (channel > 127).astype(np.float32)


## Submodules (Encoders, Decoders, Bottlenecks)

In [3]:
###########################################
################ Encoders #################
###########################################
class ImgEncoder(nn.Module):
    """CNN that embeds the last frame of a sequence into a flat vector."""

    def __init__(self, in_channel=1, hidden_channel=16, out_channel=32, h_in=256, out_dim=1024):
        """
        Build two strided convolutions followed by a linear projection.
        With the defaults a 256 x 256 image becomes a 1024 embedding vector.
        Assumes image is square.
        """
        super().__init__()

        stride = 2
        # (kernel size, padding) of the two convolutions, in order
        (ker_a, pad_a), (ker_b, pad_b) = conv_specs = ((16, 7), (4, 1))

        # Track the spatial side length through both convolutions
        side = h_in
        for kernel, pad in conv_specs:
            side = (side + 2 * pad - kernel) // stride + 1

        layers = [
            # defaults: (1, 256, 256) -> (16, 128, 128)
            nn.Conv2d(in_channel, hidden_channel, ker_a, stride=stride, padding=pad_a),
            nn.ReLU(),
            # defaults: (16, 128, 128) -> (32, 64, 64)
            nn.Conv2d(hidden_channel, out_channel, ker_b, stride=stride, padding=pad_b),
            nn.ReLU(),
            # defaults: (32, 64, 64) -> 131072
            nn.Flatten(start_dim=1, end_dim=-1),
            # defaults: 131072 -> 1024
            nn.Linear(out_channel * side ** 2, out_dim),
        ]
        self.conv_stack = nn.Sequential(*layers)

    def forward(self, x):
        """
        Encode only the final time step.
        Input has shape (T, B, 1, H, W); output has shape (B, out_dim).
        """
        last_frames = x[-1]
        return self.conv_stack(last_frames)

class ImgSeqEncoder(nn.Module):
    def __init__(self, in_channel=1, hidden_channel=16, out_channel=32, h_in=256, out_dim=1024):
        """
        Define convolutional neural network architecture for compressing each image of a SEQUENCE
        of square images into an embedding vector (256 x 256 -> 1024 with the defaults).
        Assumes image is square.
        """
        super().__init__()

        # Compute image side length after both convolutions
        stride = 2
        padding_one, padding_two = 7, 1
        num_ker_one, num_ker_two = 16, 4
        h_out = (h_in + 2 * padding_one - num_ker_one) // stride + 1
        h_out = (h_out + 2 * padding_two - num_ker_two) // stride + 1

        # Define CNN
        linear_in_dim = out_channel * h_out ** 2
        linear_out_dim = out_dim
        self.conv_stack = nn.Sequential(
                            nn.Conv2d(in_channel, hidden_channel, num_ker_one, stride=stride, padding=padding_one), # defaults: (256, 256) -> (128, 128)
                            nn.ReLU(),
                            nn.Conv2d(hidden_channel, out_channel, num_ker_two, stride=stride, padding=padding_two), # defaults: (128, 128) -> (64, 64)
                            nn.ReLU(),
                            nn.Flatten(start_dim=1, end_dim=-1), # defaults: (32, 64, 64) -> 131072
                            nn.Linear(linear_in_dim, linear_out_dim) # defaults: 131072 -> 1024
        )

    def forward(self, x):
        """
        Encode every frame of the sequence independently.
        Input has shape (T, B, C, H, W); output has shape (T, B, out_dim).
        """
        seq_len, batch_size, channels, h, w = x.shape
        # Fold time into the batch axis so the 2-D convolutions see one image
        # per row. The real channel count is kept (it used to be hard-coded
        # to 1, which broke any in_channel other than 1).
        x = x.reshape(seq_len * batch_size, channels, h, w)
        return self.conv_stack(x).reshape(seq_len, batch_size, -1)
    
class Conv3DImgSeqEncoder(nn.Module):
    """
    Encode every frame of an image sequence with 3-D convolutions whose
    temporal kernel size is 1, i.e. frames do not mix inside this module.
    """

    def __init__(self, in_channel=1, hidden_channel=16, out_channel=32, h_in=256, out_dim=4096):
        """
        Build three stride-2 spatial convolutions and a linear projection.
        Assumes image is square with side length h_in.
        """
        super().__init__()

        # Side length after the three stride-2 convolutions below. Computed
        # from the conv arithmetic rather than as h_in // 8 so the linear
        # layer also has the right size when h_in is not a multiple of 8
        # (for h_in = 256 both give 32).
        h_out = h_in
        for kernel, pad in ((7, 3), (3, 1), (3, 1)):
            h_out = (h_out + 2 * pad - kernel) // 2 + 1

        self.conv_stack = nn.Sequential(
            nn.Conv3d(in_channels=in_channel, out_channels=hidden_channel,
                      kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3)),  # (B, C=1, T, H=256, W=256) --> (B, 16, T, 128, 128)
            nn.ReLU(),

            nn.Conv3d(in_channels=hidden_channel, out_channels=out_channel,
                      kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), # (B, 16, T, 128, 128) --> (B, 32, T, 64, 64)
            nn.ReLU(),

            nn.Conv3d(in_channels=out_channel, out_channels=out_channel,
                      kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), # (B, 32, T, 64, 64) --> (B, 32, T, 32, 32)
            nn.ReLU(),
        )

        self.linear = nn.Linear(out_channel * h_out * h_out, out_dim)

    def forward(self, x):
        """
        x: (T, B, C, H, W)
        return: (T, B, out_dim)
        """
        x = x.permute(1, 2, 0, 3, 4)  # -> (B, C, T, H, W), the layout Conv3d expects
        x = self.conv_stack(x)        # -> (B, C_out, T, H', W')
        x = x.permute(2, 0, 1, 3, 4)  # -> (T, B, C_out, H', W')
        seq_len, batch_size = x.shape[:2]
        x = x.reshape(seq_len, batch_size, -1)  # -> (T, B, C_out * H' * W')
        return self.linear(x)         # -> (T, B, out_dim)

class StateEncoder(nn.Module):
    """Single linear layer that embeds the state of the last time step."""

    def __init__(self, in_dim=13, out_dim=128):
        """
        Create the projection from a length `in_dim` pose-and-twist vector
        to a length `out_dim` embedding (13 -> 128 by default).
        """
        super().__init__()
        self.linear_layer = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """
        Neural network forward pass.
        Input has shape (T, B, D); only the final step, of shape (B, D), is embedded.
        """
        final_state = x[-1]
        return self.linear_layer(final_state)
    
class StateSeqEncoder(nn.Module):
    """Single linear layer applied to the state of every time step."""

    def __init__(self, in_dim=13, out_dim=128):
        """
        Create the projection that maps each length `in_dim` pose-and-twist
        vector to a length `out_dim` embedding (13 -> 128 by default).
        """
        super().__init__()
        self.linear_layer = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """
        Neural network forward pass.
        The linear layer acts on the last axis, so leading axes are kept as they are.
        """
        embedded = self.linear_layer(x)
        return embedded


###########################################
################ Decoders #################
###########################################
class ImgDecoder(nn.Module):
    """Decode a sequence of embeddings into one image, using time as the channel axis."""

    def __init__(self, in_channels=5, hidden_channels=3, out_channels=1, in_dim=512, hidden_dim=4096):
        """
        Reconstructs the image from the embedding vectors. (B, T, in_dim): (B, 5, 512)
        With the defaults the output image size is (256, 256).
        """
        super().__init__()

        # Each embedding is projected to hidden_dim values and viewed as a square image
        side = int(np.sqrt(hidden_dim))
        stride = 1
        # (kernel size, padding) of the two convolutions; the padding is
        # larger than the kernel needs, so every convolution grows the image
        (ker_a, pad_a), (ker_b, pad_b) = conv_specs = ((5, 34), (15, 71))

        # Output side length after both convolutions
        grown = side
        for kernel, pad in conv_specs:
            grown = (grown + 2 * pad - kernel) // stride + 1
        self.h_out = grown

        # Modules
        layers = [
            nn.Linear(in_dim, hidden_dim),  # (B, T, hidden_dim)
            nn.ReLU(),
            nn.Unflatten(dim=-1, unflattened_size=(side, side)),  # (B, T, side, side)
            nn.Conv2d(in_channels, hidden_channels, ker_a, stride, pad_a),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, out_channels, ker_b, stride, pad_b),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """
        Input has shape (T, B, D_emb) = (5, 16, 512)
        Return one predicted image per batch element, (B, out_channels, h_out, h_out).
        The sequence dimension is treated as the image channel dimension, so T must
        equal in_channels, and the convolutions grow the image back to full size.

        TODO: Consider instead of passing (T, B, D_emb) as (T * B, D_emb) and reconstructing as
        (T, B, 1, h_out, h_out), pass it as (B, T * D_emb) and reconstruct as (B, 1, h_out, h_out).
        """
        batch_first = x.permute(1, 0, 2)  # Convnet expects (B, T, H, W)
        return self.decoder(batch_first)
    
class ImgDecoder2(nn.Module):
    """Decode each embedding of a sequence into its own image by upsample + conv stages."""

    def __init__(self, in_channels=1, hidden_channels=32, out_channels=1, in_dim=512, hidden_dim=1024, out_size=256):
        """
        Build a linear projection followed by three (2x nearest upsample, 3x3 conv) stages.
        Shape comments below use the defaults, where the projected image is 32 x 32.
        """
        super().__init__()
        self.out_size = out_size
        side = int(np.sqrt(hidden_dim))

        def upsample():
            # Each call returns a fresh parameter-free 2x nearest-neighbour upsampler
            return nn.Upsample(scale_factor=2, mode='nearest')

        stages = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Unflatten(dim=-1, unflattened_size=(1, side, side)),  # (T*B, 1, 32, 32)
            upsample(),  # (T*B, 1, 64, 64)
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            upsample(),  # (T*B, 32, 128, 128)
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            upsample(),  # (T*B, 32, 256, 256)
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """
        Input has shape (T, B, D); output has shape (T, B, 1, out_size, out_size).
        """
        seq_len, batch_size, emb_dim = x.shape
        # Fold time into the batch axis so each embedding is decoded on its own
        flat = x.reshape(seq_len * batch_size, emb_dim)
        images = self.decoder(flat)
        return images.reshape(seq_len, batch_size, 1, self.out_size, self.out_size)

class Conv3DImgSeqDecoder(nn.Module):
    """
    Decode every embedding of a sequence into an image with transposed 3-D
    convolutions whose temporal kernel size is 1 (frames do not mix here).
    """

    def __init__(self, in_dim=4096, out_channels=1, h_out=256):
        """
        Build the linear projection and three 2x spatial upsampling stages.
        """
        super().__init__()
        self.h_mid = h_out // 8  # 32 for the default h_out; three 2x stages restore h_out
        self.c_mid = 32

        self.linear = nn.Linear(in_dim, self.c_mid * self.h_mid * self.h_mid)

        # All three stages share the same geometry: spatial size doubles, time is untouched
        double_hw = dict(kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))
        self.deconv_stack = nn.Sequential(
            nn.ConvTranspose3d(self.c_mid, self.c_mid, **double_hw),  # 32 -> 64
            nn.ReLU(),
            nn.ConvTranspose3d(self.c_mid, 16, **double_hw),          # 64 -> 128
            nn.ReLU(),
            nn.ConvTranspose3d(16, out_channels, **double_hw),        # 128 -> 256
        )

    def forward(self, x):
        """
        x: (T, B, D)
        return: (T, B, out_channels, h_out, h_out)
        """
        seq_len, batch_size, _ = x.shape
        projected = self.linear(x)                                   # (T, B, C*H*W)
        volume = projected.view(seq_len, batch_size, self.c_mid, self.h_mid, self.h_mid)
        volume = volume.permute(1, 2, 0, 3, 4)                       # (B, C, T, H, W)
        decoded = self.deconv_stack(volume)                          # (B, out_channels, T, 8H, 8W)
        return decoded.permute(2, 0, 1, 3, 4)                        # (T, B, out_channels, 8H, 8W)


###########################################
################ Bottlenecks ##############
###########################################
class LSTM(nn.Module):
    """Thin wrapper around nn.LSTM that returns only the per-step outputs."""

    def __init__(self, in_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.model = nn.LSTM(in_dim, hidden_dim, num_layers=num_layers)

    def forward(self, x):
        """
        Run the LSTM and discard the final (hidden, cell) state.
        """
        outputs, _final_state = self.model(x)
        return outputs
    
class MLP(nn.Module):
    """Fully connected bottleneck: Linear + ReLU stages followed by a final Linear."""

    def __init__(self, in_dim, hidden_dim, num_layers=3):
        """
        Build at least one (Linear, ReLU) stage, `num_layers` of them when
        that is larger, then one closing Linear without activation. All
        widths after the first layer equal hidden_dim.
        """
        super().__init__()

        # Layer widths of the activated stages: in_dim -> hidden_dim -> ... -> hidden_dim
        widths = [in_dim] + [hidden_dim] * max(num_layers, 1)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        stack.append(nn.Linear(hidden_dim, hidden_dim))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """
        Input has shape (B, D); a leading axis of size 1 is added to the
        result because the decoder expects (1, B, D).
        """
        hidden = self.model(x)
        return hidden.unsqueeze(0)


## Loss Function 

In [4]:
# Loss functions 
class BCEDiceLoss(nn.Module):
    """
    Weighted sum of a positive-class-weighted binary cross entropy and a
    soft Dice loss, both computed from raw logits.
    """

    def __init__(self, weight_bce=0.7, weight_dice=0.3, smooth=1.0, weight=250.0):
        """
        Args:
            weight_bce: factor applied to the BCE term.
            weight_dice: factor applied to the Dice term.
            smooth: added to the Dice numerator and denominator to avoid division by zero.
            weight: BCE weight of the positive class (pos_weight).
        """
        super().__init__()
        # Created without an explicit device so the loss no longer depends on
        # a module-level global and can be constructed without CUDA.
        # BCEWithLogitsLoss registers pos_weight as a buffer, so it moves
        # with `.to(...)` of this module or any parent module.
        self.bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([weight]))
        self.weight_bce = weight_bce
        self.weight_dice = weight_dice
        self.smooth = smooth  # to avoid division by zero

    def forward(self, logits, targets):
        """
        Return the scalar loss for raw `logits` against `targets` of the same shape.
        """
        # Keep working when the module itself was never moved but the inputs
        # live on another device.
        if self.bce.pos_weight.device != logits.device:
            self.bce.pos_weight = self.bce.pos_weight.to(logits.device)

        # BCEWithLogitsLoss expects raw logits
        bce_loss = self.bce(logits, targets)

        # Apply sigmoid to get probabilities, then flatten everything so the
        # Dice score is computed over the whole batch at once
        probs = torch.sigmoid(logits)
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice_score = (2. * intersection + self.smooth) / (
            probs.sum() + targets.sum() + self.smooth
        )
        dice_loss = 1 - dice_score

        return self.weight_bce * bce_loss + self.weight_dice * dice_loss

## Main Module (FilterNet)

In [5]:
class FilterNet(nn.Module):
    def __init__(self, in_dim=4224, hidden_dim=4096, seq_length=3, augmented=False):
        """
        Define LSTM architecture with image and state encoders.
        The image and state embeddings are concatenated per time step to form the LSTM input,
        so in_dim = image embedding size + state embedding size (4096 + 128 = 4224 by default).
        The LSTM output of size hidden_dim is decoded back into one frame per time step.

        conv3d vals: in_dim 4224, hidden_dim 4096

        Raises:
            ValueError: if in_dim leaves no room for the image embedding.
        """
        super().__init__()

        # Parameters
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.seq_len = seq_length
        self.augmented = augmented

        # Size the sub-modules from in_dim / hidden_dim so non-default values
        # stay consistent; the defaults give exactly the same layers as before.
        state_emb_dim = 128
        img_emb_dim = in_dim - state_emb_dim
        if img_emb_dim <= 0:
            raise ValueError(f"in_dim must exceed the state embedding size {state_emb_dim}, got {in_dim}")

        # Conv3d approach
        self.image_encoder = Conv3DImgSeqEncoder(out_dim=img_emb_dim)
        self.state_encoder = StateSeqEncoder(out_dim=state_emb_dim)
        self.lstm = LSTM(in_dim, hidden_dim)
        self.image_decoder = Conv3DImgSeqDecoder(in_dim=hidden_dim)
        self.loss_fun = BCEDiceLoss()


    def loss(self, sequence):
        """
        Run the model on one batch and compare the prediction with the target frames.

        Unless using the augmented (B, T, 1, H, W, 1 + D_state) tensor, `sequence` is a dict with
        'input' = (frames (B, T, 1, H, W), state (B, T, D_state)) and 'target' = (B, T, 1, H, W).
        Keeping frames and state separate avoids the memory overhead of the augmented tensor.

        Returns:
            (loss, pred_frames) where pred_frames holds raw logits of shape (T, B, 1, H, W).
        """

        if self.augmented:
            # Retrieve data and switch to sequence first
            aug_input = sequence['input'].permute(1, 0, 2, 3, 4, 5)  # (T, B, 1, H, W, 1 + D_state)
            frames = aug_input[..., 0]                                # (T, B, 1, H, W)
            # The state is identical at every pixel, so read it from pixel
            # (0, 0) of channel 0 to get (T, B, D_state) as the encoder expects.
            state = aug_input[:, :, 0, 0, 0, 1:]
            # The decoder predicts every frame of the sequence, so the target
            # is the full sequence as well.
            out_frame = sequence['target'].permute(1, 0, 2, 3, 4)     # (T, B, 1, H, W)
        else:
            frames = sequence['input'][0].permute(1, 0, 2, 3, 4) # (B, T, 1, H, W) --> (T, B, 1, H, W)
            state = sequence['input'][1].permute(1, 0, 2) # (B, T, D_state) --> (T, B, D_state)
            out_frame = sequence['target'].permute(1, 0, 2, 3, 4) #(B, T, 1, H, W) --> (T, B, 1, H, W)

        # Pass inputs through encoders
        img_embedding = self.image_encoder(frames) # (T, B, d_img_emb)
        state_embedding = self.state_encoder(state) # (T, B, d_state_emb)

        # Concatenate the embeddings of each time step
        compressed_input = torch.cat((img_embedding, state_embedding), dim=-1) # (T, B, in_dim)

        # Pass compressed sequence through LSTM
        lstm_out = self.lstm(compressed_input) # (T, B, hidden_dim)

        # Pass through decoder to reconstruct every frame in the sequence
        pred_frames = self.image_decoder(lstm_out)

        # Loss between predicted logits and target frames
        return self.loss_fun(pred_frames, out_frame), pred_frames


## Prepare data and model

In [6]:
# Define training parameters
num_workers = 0
num_vids = 7 
path_to_data = '/home/jrached/cv_project_code/project/data/filter_net/processed_flow'
augmented = False
batch_size = 1  # the export cells below squeeze the batch dimension, so keep this at 1
seq_length = 3
img_size = 256

dataset = VideoDataset(path_to_data, augmented=augmented, num_vids=num_vids, seq_length=seq_length, img_size=img_size)

val_ratio = 0.2
val_size = int(len(dataset) * val_ratio)
train_size = len(dataset) - val_size

# Contiguous (not random) split: the first 80% of the sequences, then the rest
train_dataset = torch.utils.data.Subset(dataset, list(range(train_size)))
val_dataset = torch.utils.data.Subset(dataset, list(range(train_size, len(dataset))))

# shuffle=False keeps sequences in video/frame order, which the export
# cells below rely on to detect the first sequence of each video
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Initialize model and load weights 
path_to_save = "/home/jrached/cv_project_code/cv_project/models/filternet1.pt"
model = FilterNet(seq_length=seq_length)
model = model.to(device) 
# map_location remaps the checkpoint tensors onto `device`, so the weights
# load even if they were saved from a different device
model.load_state_dict(torch.load(path_to_save, map_location=device, weights_only=True))
model.eval()

  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [00:00<00:00, 125.58it/s]


FilterNet(
  (image_encoder): Conv3DImgSeqEncoder(
    (conv_stack): Sequential(
      (0): Conv3d(1, 16, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3))
      (1): ReLU()
      (2): Conv3d(16, 32, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
      (3): ReLU()
      (4): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
      (5): ReLU()
    )
    (linear): Linear(in_features=32768, out_features=4096, bias=True)
  )
  (state_encoder): StateSeqEncoder(
    (linear_layer): Linear(in_features=13, out_features=128, bias=True)
  )
  (lstm): LSTM(
    (model): LSTM(4224, 4096)
  )
  (image_decoder): Conv3DImgSeqDecoder(
    (linear): Linear(in_features=4096, out_features=32768, bias=True)
    (deconv_stack): Sequential(
      (0): ConvTranspose3d(32, 32, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))
      (1): ReLU()
      (2): ConvTranspose3d(32, 16, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))
      (3): ReLU(

## Filter data 


In [18]:
dest_path = '/home/jrached/cv_project_code/project/data/filter_net/filtered_flow'
data_path = '/home/jrached/cv_project_code/project/data/filter_net/dataset'


# Create the output folders of every video (existing folders are left as they are)
output_subdirs = ('mask', 'depth', 'intrinsics', 'poses')
for vid_idx in range(1, num_vids + 1):
    video_root = os.path.join(dest_path, f'test{vid_idx}')
    for subdir in output_subdirs:
        os.makedirs(os.path.join(video_root, subdir), exist_ok=True)

In [None]:
h, w = 480, 848
resize_frame = transforms.Resize((h, w))


def save_filtered_frame(frame_logits, test_idx, frame_idx):
    """
    Threshold one predicted frame, save it as a mask PNG, and copy the
    matching depth image next to it.

    frame_logits holds the raw logits of one time step with the batch axis
    first; only batch element 0 is saved (the loader uses batch_size 1).
    """
    prob_pred_frame = torch.sigmoid(frame_logits)
    out = (prob_pred_frame > 0.5).to(torch.float32)
    output_frame = resize_frame(out[0, :, :, :])
    image = Image.fromarray((output_frame[0, :, :] * 255).cpu().detach().numpy().astype(np.uint8))
    image.save(os.path.join(dest_path, f'test{test_idx}/mask/frame{frame_idx}.png'))
    # The return code is not checked, so a missing depth image does not stop the export
    subprocess.run(['cp', os.path.join(data_path, f'test{test_idx}/depth/depth_img{frame_idx}.png'), os.path.join(dest_path, f'test{test_idx}/depth/depth_img{frame_idx}.png')])


is_first_seq = True 
prev_test_idx = 1 
# Inference only: disabling autograd avoids building a graph for every batch
with torch.no_grad():
    for sequence in tqdm(train_loader): 
        # Move data to device
        frames, state = sequence['input']
        out_frame = sequence['target']
        indices = sequence['indices'].squeeze(0) # Squeeze batch dimension 

        sequence['input'] = (frames.to(device), state.to(device))
        sequence['target'] = out_frame.to(device)

        # Forward pass; only the predicted logits are used below
        loss, pred  = model.loss(sequence)

        # Consecutive windows overlap, so a full window is written only for
        # the first sequence of each video
        test_idx = int(indices[0][0].item())
        if prev_test_idx != test_idx: 
            is_first_seq = True 
            prev_test_idx = test_idx 

        if is_first_seq: 
            # If first sequence, save full sequence
            for i, index in enumerate(indices): 
                save_filtered_frame(pred[i], int(index[0].item()), int(index[1].item()))

            # Poses and intrinsics are copied once per video
            subprocess.run(f'cp {data_path}/test{test_idx}/poses/* {dest_path}/test{test_idx}/poses/.', shell=True)
            subprocess.run(f'cp {data_path}/test{test_idx}/intrinsics/* {dest_path}/test{test_idx}/intrinsics/.', shell=True)
            is_first_seq = False 
        else: 
            # If not first sequence, only save last element 
            save_filtered_frame(pred[-1], int(indices[-1][0].item()), int(indices[-1][1].item()))


    
    

 45%|████▍     | 7630/17095 [04:14<05:04, 31.05it/s]

In [None]:
# h, w = 480, 848
# resize_frame = transforms.Resize((h, w))

# is_first_seq = True 
# prev_test_idx = 1 
# for sequence in tqdm(val_loader): 
#     # Move data to device
#     frames, state = sequence['input']
#     out_frame = sequence['target']
#     indices = sequence['indices'].squeeze(0) # Squeeze batch dimension 

#     sequence['input'] = (frames.to(device), state.to(device))
#     sequence['target'] = out_frame.to(device)

#     # Compute the loss and its gradients
#     loss, pred  = model.loss(sequence)

#     test_idx = int(indices[0][0].item())
#     if prev_test_idx != test_idx: 
#         is_first_seq = True 
#         prev_test_idx = test_idx 

#     if is_first_seq: 
#         # If first sequence, save full sequence
#         for i, index in enumerate(indices): 
#             test_idx, frame_idx = int(index[0].item()), int(index[1].item()) 

#             prob_pred_frame = torch.sigmoid(pred[i])
#             out = (prob_pred_frame > 0.5).to(torch.float32) 
#             output_frame = resize_frame(out[0, :, :, :])
#             image = Image.fromarray((output_frame[0, :, :] * 255).cpu().detach().numpy().astype(np.uint8))
#             image.save(os.path.join(dest_path, f'test{test_idx}/mask/frame{frame_idx}.png'))
#             subprocess.run(['cp', os.path.join(data_path, f'test{test_idx}/depth/depth_img{frame_idx}.png'), os.path.join(dest_path, f'test{test_idx}/depth/depth_img{frame_idx}.png')])

#         subprocess.run(f'cp {data_path}/test{test_idx}/poses/* {dest_path}/test{test_idx}/poses/.', shell=True)
#         subprocess.run(f'cp {data_path}/test{test_idx}/intrinsics/* {dest_path}/test{test_idx}/intrinsics/.', shell=True)
#         is_first_seq = False 
#     else: 
#         # If not first sequence, only save last element 
#         pred = pred[-1] 
#         test_idx, frame_idx = int(indices[-1][0].item()), int(indices[-1][1].item()) 

#         prob_pred_frame = torch.sigmoid(pred)
#         out = (prob_pred_frame > 0.5).to(torch.float32) 
#         output_frame = resize_frame(out[0, :, :, :])
#         image = Image.fromarray((output_frame[0, :, :] * 255).cpu().detach().numpy().astype(np.uint8))
#         image.save(os.path.join(dest_path, f'test{test_idx}/mask/frame{frame_idx}.png'))
#         subprocess.run(['cp', os.path.join(data_path, f'test{test_idx}/depth/depth_img{frame_idx}.png'), os.path.join(dest_path, f'test{test_idx}/depth/depth_img{frame_idx}.png')])


    
    

100%|██████████| 4273/4273 [02:48<00:00, 25.34it/s]
