In [1]:
import torch
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.verbose = False
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets

In [2]:
# Prefer the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'Using {device} device')

Using cuda:0 device


## ConvLSTMCell

ConvLSTM implementation source: https://github.com/ndrplz/ConvLSTM_pytorch

In [3]:
import torch.nn as nn
import torch


class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    One 2-D convolution is applied to the channel-wise concatenation of the
    current input and the previous hidden state. Its ``4 * hidden_dim`` output
    channels are split into the input, forget and output gates and the
    candidate cell update.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """

        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        # Half-kernel padding: with the default stride of 1 this keeps height and
        # width unchanged when both kernel dimensions are odd.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # A single convolution produces the pre-activations of all four gates at once.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """
        Advance the cell by one time step.

        Parameters
        ----------
        input_tensor: torch.Tensor
            Input for this step; concatenated with the hidden state along dim 1,
            so it must have ``input_dim`` channels on that axis.
        cur_state: (torch.Tensor, torch.Tensor)
            Previous hidden state ``h`` and cell state ``c``.

        Returns
        -------
        (h_next, c_next): the updated hidden state and cell state.
        """
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis

        combined_conv = self.conv(combined)
        # Four chunks of `hidden_dim` channels each: input, forget, output gate, candidate.
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """
        Create an all-zero initial state ``(h, c)``.

        Both tensors have shape ``(batch_size, hidden_dim, height, width)`` and are
        allocated on the same device and with the same dtype as the convolution
        weight, so the state stays compatible with the cell after ``.to(...)``,
        ``.double()`` or ``.half()``.

        Parameters
        ----------
        batch_size: int
            Number of samples in the batch.
        image_size: (int, int)
            Height and width of the feature maps.
        """
        height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))


class ConvLSTM(nn.Module):

    """
    Multi-layer convolutional LSTM built from stacked ConvLSTMCell modules.

    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels (an int shared by all layers, or a list with one int per layer)
        kernel_size: Size of kernel in convolutions (a tuple shared by all layers, or a list with one tuple per layer)
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch or not
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
        Note: Will do same padding.

    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
            0 - layer_output_list holds one tensor per layer with the hidden states of all
                    time steps stacked along dim 1, i.e. (B, T, hidden_dim, H, W); this is
                    batch-first even when batch_first is False
            1 - last_state_list is the list of last states
                    each element of the list is a list [h, c] for hidden state and memory
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super().__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(self.num_layers):
            # The first layer consumes the raw input; deeper layers consume the
            # hidden state of the layer directly below them.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        Run the whole sequence through every layer.

        Parameters
        ----------
        input_tensor: torch.Tensor
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state: sequence of (h, c) pairs, optional
            Initial state, one (h, c) pair per layer. When None, every layer
            starts from the all-zero state returned by its cell's `init_hidden`.

        Returns
        -------
        layer_output_list, last_state_list

        Raises
        ------
        ValueError
            If `hidden_state` is given but does not hold exactly one pair per layer.
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        batch_size, seq_len, _, height, width = input_tensor.size()

        if hidden_state is None:
            # Since the init is done in forward. Can send image size here
            hidden_state = self._init_hidden(batch_size=batch_size,
                                             image_size=(height, width))
        elif len(hidden_state) != self.num_layers:
            raise ValueError('`hidden_state` must contain one (h, c) pair per layer.')

        layer_output_list = []
        last_state_list = []

        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)

            # Stack the per-step hidden states along the time axis; the result is
            # both this layer's output and the next layer's input sequence.
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            # Keep one-element lists so the return structure is the same in both modes.
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        """Return a list with the zero initial (h, c) state of every layer."""
        return [cell.init_hidden(batch_size, image_size) for cell in self.cell_list]

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        """Raise ValueError unless `kernel_size` is a tuple or a list made only of tuples."""
        is_tuple = isinstance(kernel_size, tuple)
        is_list_of_tuples = (isinstance(kernel_size, list) and
                             all(isinstance(elem, tuple) for elem in kernel_size))
        if not (is_tuple or is_list_of_tuples):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        """Repeat a non-list `param` once per layer; lists are returned unchanged."""
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

## Model

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvLSTMNetwork(nn.Module):
    """Frame predictor: a stacked ConvLSTM followed by a 1x1 convolution to a single channel.

    Parameters
    ----------
    input_dim: int
        Number of channels of each input frame.
    hidden_dims: int or list of int
        Hidden channels of the ConvLSTM layers (one value per layer, or a single
        int shared by every layer).
    kernel_size: tuple or list of tuples
        Convolution kernel size(s) forwarded to ConvLSTM.
    num_layers: int
        Number of stacked ConvLSTM layers.
    """

    def __init__(self, input_dim, hidden_dims, kernel_size, num_layers):
        super().__init__()
        self.convlstm = ConvLSTM(input_dim=input_dim,
                                 hidden_dim=hidden_dims,
                                 kernel_size=kernel_size,
                                 num_layers=num_layers,
                                 batch_first=True,
                                 bias=True,
                                 return_all_layers=False)
        # The last hidden layer's output is used for final image prediction.
        # ConvLSTM also accepts a single int for all layers, so only index into
        # `hidden_dims` when it really is a per-layer list.
        final_hidden_dim = hidden_dims[-1] if isinstance(hidden_dims, list) else hidden_dims
        self.final_conv = nn.Conv2d(final_hidden_dim, 1, kernel_size=1)  # Output 1 channel image

    def forward(self, x):
        """
        Predict one frame from a sequence of frames.

        x should be of shape (batch, sequence, channels, height, width).
        Returns the raw output of the final 1x1 convolution (no activation is
        applied here), computed from the last layer's hidden state at the last
        time step.
        """
        lstm_out, _ = self.convlstm(x)  # Get output from ConvLSTM
        # lstm_out is a one-element list (return_all_layers=False); take the last time step of it
        last_output = lstm_out[-1][:, -1, :, :, :]
        final_output = self.final_conv(last_output)  # Convert to the required output image
        return final_output

# Parameters for Model Architecture:
channels = 1  # Grayscale images
hidden_dims = [64, 64, 128]  # Hidden dimensions for each ConvLSTM layer
kernel_size = (3, 3)  # Kernel size for ConvLSTM cells
num_layers = 3  # Number of layers in ConvLSTM; must equal len(hidden_dims) or ConvLSTM raises ValueError

# Build the network from the settings above.
model = ConvLSTMNetwork(input_dim=channels,
                        hidden_dims=hidden_dims,
                        kernel_size=kernel_size,
                        num_layers=num_layers)
# Move the model to the selected device; as the last expression of the cell,
# the returned module is also displayed.
model.to(device)

ConvLSTMNetwork(
  (convlstm): ConvLSTM(
    (cell_list): ModuleList(
      (0): ConvLSTMCell(
        (conv): Conv2d(65, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1): ConvLSTMCell(
        (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (2): ConvLSTMCell(
        (conv): Conv2d(192, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
  )
  (final_conv): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
)

In [5]:
# Dummy input data: Batch size of 1, 5 time steps, 1 channel, 50x50 pixels
input_tensor = torch.rand(1, 5, 1, 50, 50).to(device)
# Forward pass through the model
output_image = model(input_tensor)
print("Output image shape:", output_image.shape)  # Expected shape: (batch, channels, height, width)

Output image shape: torch.Size([1, 1, 50, 50])


In [6]:
def print_model_parameters(model):
    """Print the element count of every trainable parameter of ``model``, followed by their sum.

    Parameters whose ``requires_grad`` is False are left out of both the
    listing and the total.
    """
    print("Layer-wise parameters:\n")
    trainable_sizes = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]
    for name, size in trainable_sizes:
        print(f"{name}: {size}")
    grand_total = sum(size for _, size in trainable_sizes)
    print(f"\nTotal parameters: {grand_total}")

# List the trainable parameter counts of the network.
print_model_parameters(model)

Layer-wise parameters:

convlstm.cell_list.0.conv.weight: 149760
convlstm.cell_list.0.conv.bias: 256
convlstm.cell_list.1.conv.weight: 294912
convlstm.cell_list.1.conv.bias: 256
convlstm.cell_list.2.conv.weight: 884736
convlstm.cell_list.2.conv.bias: 512
final_conv.weight: 128
final_conv.bias: 1

Total parameters: 1330561


In [7]:
#from torchsummary import summary
#summary(model, input_size=(5, 1, 100, 100)) # the summary call breaks the model

!pip install torchinfo
from torchinfo import summary # older depricated 'torchinfo' works

# Assuming the model and input_tensor are defined as shown previously
summary(model, input_sizes=(1, 5, 1, 400, 300))


Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


Layer (type:depth-idx)                   Param #
ConvLSTMNetwork                          --
├─ConvLSTM: 1-1                          --
│    └─ModuleList: 2-1                   --
│    │    └─ConvLSTMCell: 3-1            150,016
│    │    └─ConvLSTMCell: 3-2            295,168
│    │    └─ConvLSTMCell: 3-3            885,248
├─Conv2d: 1-2                            129
Total params: 1,330,561
Trainable params: 1,330,561
Non-trainable params: 0

## Data

In [8]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

In [9]:
# Mount Google Drive under /content/drive so files stored there can be read from Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [10]:
# Copy the zip file from Google Drive to the local Colab environment, unzip it
# quietly into /content/game_frames, then list that directory to check the result.
!cp "/content/drive/My Drive/game_frames.zip" "/content/game_frames.zip"
!unzip -q "/content/game_frames.zip" -d "/content/game_frames"
!ls "/content/game_frames"

game_frames


In [None]:
class ImageSequenceDataset(Dataset):
    """Sliding-window dataset over the ``.jpg`` frames of a single directory.

    Item ``idx`` is built from ``sequence_length`` consecutive frames, taken in
    sorted filename order starting at position ``idx``: all but the last frame
    form the input sequence and the last frame is the target.
    """

    def __init__(self, root_dir, transform=None, sequence_length=6):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
                It must turn each image into a tensor, because the frames of a
                sequence are combined with ``torch.stack``.
            sequence_length (int): Number of images in each sequence (inputs plus
                the one target frame), so it must be at least 2.

        Raises:
            ValueError: If ``sequence_length`` is smaller than 2.
        """
        if sequence_length < 2:
            raise ValueError('sequence_length must be at least 2 (one input frame plus the target).')
        self.root_dir = root_dir
        self.transform = transform
        self.sequence_length = sequence_length
        # Sorted (lexicographic) filename order defines the order of the frames.
        self.image_filenames = [f for f in sorted(os.listdir(root_dir)) if f.endswith('.jpg')]

    def __len__(self):
        # Number of complete windows; clamped at 0 because len() must not be
        # negative when there are fewer frames than one window needs.
        return max(0, len(self.image_filenames) - (self.sequence_length - 1))

    def __getitem__(self, idx):
        """Return ``(sequence, target)`` for the window starting at frame ``idx``.

        ``sequence`` stacks the first ``sequence_length - 1`` transformed frames
        along a new leading dimension; ``target`` is the transformed last frame.
        Negative indices count from the end; out-of-range indices raise IndexError.
        """
        num_windows = len(self)
        if idx < 0:
            # Without this, a negative start would make the window wrap from the
            # last frames around to the first ones.
            idx += num_windows
        if not 0 <= idx < num_windows:
            raise IndexError('sequence index out of range')

        images = []
        for i in range(self.sequence_length):
            img_name = os.path.join(self.root_dir, self.image_filenames[idx + i])
            # The context manager closes the file once the grayscale copy exists.
            with Image.open(img_name) as source_image:
                image = source_image.convert('L')  # Convert to grayscale
            if self.transform:
                image = self.transform(image)
            images.append(image)

        # Stack images to create a sequence tensor
        # Assumes that images are transformed to tensors by `transforms`
        sequence = torch.stack(images[:-1])  # All but last for input sequence
        target = images[-1]  # Last image as ground truth
        return sequence, target

# Transform to tensor and resize if necessary.
# This pipeline is applied to every frame after the dataset has converted it to grayscale.
transform = transforms.Compose([
    transforms.Resize((100, 100)),  # Resize all images to the same size (100x100)
    transforms.ToTensor(),  # Convert images to tensor
    # Normalization is deliberately left disabled while BCEWithLogitsLoss is the criterion
    # (presumably so the target frames keep the value range produced by ToTensor — TODO confirm).
    #transforms.Normalize((0.5,), (0.5,))  # Normalize images; mean and std are tuples with one value per channel
])

In [21]:
# Sliding-window samples over the extracted frames, batched 16 windows at a time.
dataset = ImageSequenceDataset('/content/game_frames/game_frames', transform=transform)
# NOTE(review): every sample already contains its own ordered window of frames, so
# `shuffle=False` only fixes the order in which windows are batched — confirm this is intended.
dataloader = DataLoader(dataset, batch_size=16, shuffle=False)  # Set `shuffle=False` to maintain sequence order !!!

In [13]:
from torch.utils.data import DataLoader
# Sanity check: look at the first batch only and report tensor shapes and dtypes.
for sequences, targets in dataloader:
    print("Batch of sequences shape:", sequences.shape)
    print("Batch of sequences type:", sequences.dtype)
    print("Batch of targets shape:", targets.shape)
    print("Batch of targets type:", targets.dtype)
    break  # one batch is enough for the check

Batch of sequences shape: torch.Size([16, 5, 1, 100, 100])
Batch of sequences type: torch.float32
Batch of targets shape: torch.Size([16, 1, 100, 100])
Batch of targets type: torch.float32


## Training

In [14]:
import torch.optim as optim
# Loss function
# Alternatives that were tried or considered are kept below for reference.
#criterion = torch.nn.MSELoss() # mean square error
#criterion = torch.nn.L1Loss() # average of absolute differences between targets and predictions
#criterion = torch.nn.HuberLoss() # quadratic for small errors and linear for large errors
# The network returns the raw output of its final convolution (no activation),
# which is the form of input a "with logits" loss is meant for.
criterion = torch.nn.BCEWithLogitsLoss()

# Background discussion on losses for binary images:
# https://discuss.pytorch.org/t/how-to-calculate-loss-for-binary-images/158987/6

In [15]:
# Adam over all model parameters with a learning rate of 0.001.
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [16]:
from tqdm.notebook import tqdm


def train_model(model, dataloader, criterion, optimizer, num_epochs=10, device=None):
    """Train ``model`` on ``dataloader`` and print the average loss of every epoch.

    Parameters
    ----------
    model: torch.nn.Module
        Network to train; it is moved to ``device`` and put in training mode.
    dataloader: iterable
        Yields ``(sequences, targets)`` batches.
    criterion: callable
        Loss called as ``criterion(model(sequences), targets)``.
    optimizer: torch.optim.Optimizer
        Optimizer that updates the parameters of ``model``.
    num_epochs: int
        Number of passes over ``dataloader``.
    device: torch.device, optional
        Device to train on. Defaults to ``cuda:0`` when CUDA is available,
        otherwise the CPU.
    """
    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.train()  # Set the model to training mode

    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)

        # Count batches explicitly: the bar's own `n` counter can lag behind the
        # real iteration count between display refreshes, which skews the average.
        for num_batches, (sequences, targets) in enumerate(progress_bar, start=1):
            sequences = sequences.to(device)  # Shape: [batch_size, seq_length-1, channels, height, width]
            targets = targets.to(device)  # Shape: [batch_size, channels, height, width]

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(sequences)
            loss = criterion(outputs, targets)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Update running loss
            running_loss += loss.item()

            # Update the progress bar description with the running average loss
            progress_bar.set_description(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss / num_batches}")

        if num_batches == 0:
            # Nothing was iterated, so there is no average to report (and no division by zero).
            print(f"Epoch {epoch+1}: dataloader yielded no batches")
            continue

        # Calculate average loss for the epoch
        epoch_loss = running_loss / num_batches
        print(f"Epoch {epoch+1}, Average Loss: {epoch_loss}")

    print('Finished Training')

In [None]:
# Train for a single epoch over the frame sequences.
train_model(model, dataloader, criterion, optimizer, num_epochs=1)

Epoch 1/1:   0%|          | 0/625 [00:00<?, ?it/s]

## Save the Model

In [26]:
# Save the entire model (the module object itself, not only its weights) to the local Colab disk.
# NOTE(review): loading this file presumably needs the model classes defined in this
# notebook to be available — confirm against the PyTorch serialization notes.
torch.save(model, 'dx_ball.pt')
# Save only the state dictionary (model weights)
torch.save(model.state_dict(), 'dx_ball_weights.pt')

In [25]:
# Save the entire model (the module object itself, not only its weights) to the mounted Google Drive.
torch.save(model, '/content/drive/My Drive/dx_ball.pt')
# Save only the state dictionary (model weights) to the mounted Google Drive.
torch.save(model.state_dict(), '/content/drive/My Drive/dx_ball_weights.pt')