In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os

In [65]:
import torch.nn as nn
import torch
import numpy as np

class Conv3DNet(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size, sequence_length):
        super(Conv3DNet, self).__init__()
        
        self.padding = kernel_size // 2
        
        self.conv1 = nn.Conv3d(input_channels, hidden_channels, kernel_size, padding=self.padding, bias=True)
        self.conv2 = nn.Conv3d(hidden_channels, hidden_channels, kernel_size, padding=self.padding, bias=True)
        self.conv3 = nn.Conv3d(hidden_channels, input_channels, (1, kernel_size, kernel_size), padding=self.padding, bias=True)
        
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.conv3(x)
        return x

def generate_data(batch_size, sequence_length, image_size, max_step_size=2):
    directions = np.random.randint(0, 4, batch_size)  # 0: right, 1: down, 2: left, 3: up
    
    # Initialize all sequences to start from the central pixel
    start_x = image_size // 2
    start_y = image_size // 2
    start_points = [(start_x, start_y) for _ in range(batch_size)]
    
    sequences = np.zeros((batch_size, 1, sequence_length, image_size, image_size))
    
    for b in range(batch_size):
        x, y = start_points[b]
        dx, dy = 0, 0
        
        if directions[b] == 0: dx = max_step_size
        if directions[b] == 1: dy = max_step_size
        if directions[b] == 2: dx = -max_step_size
        if directions[b] == 3: dy = -max_step_size
        
        for t in range(sequence_length):
            sequences[b, 0, t, y, x] = 1.0
            x = np.clip(x + dx, 0, image_size - 1)  # No toroidal looping, clamp at edges
            y = np.clip(y + dy, 0, image_size - 1)  # No toroidal looping, clamp at edges
    
    return torch.tensor(sequences, dtype=torch.float32)

def save_data_to_disk(data, filename):
    """Save the generated data tensor to disk."""
    torch.save(data, filename)

def load_data_from_disk(filename):
    """Load the data tensor from disk."""
    return torch.load(filename)



In [66]:
# Hyperparameters
batch_size = 128
sequence_length = 8
image_size = 16
input_channels = 1  # Grayscale images
hidden_channels = 64
kernel_size = 3
lr = 0.001
num_epochs = 100
print_interval = 10
data_filename = "training_data.pt"

# Initialize model, loss, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Conv3DNet(input_channels, hidden_channels, kernel_size, sequence_length).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

In [67]:
sequences = generate_data(batch_size, sequence_length + 1, image_size).to(device)
save_data_to_disk(sequences, data_filename)

In [68]:
# Check if data exists on disk, otherwise generate and save
if os.path.exists(data_filename):
    sequences = load_data_from_disk(data_filename).to(device)
else:
    sequences = generate_data(batch_size, sequence_length + 1, image_size).to(device)
    save_data_to_disk(sequences, data_filename)

inputs = sequences[:, :, :sequence_length]  # First 8 frames
targets = sequences[:, :, sequence_length]  # The 9th frame is our target

# Training loop
for epoch in range(num_epochs):
    model.train()

    # Forward pass
    outputs = model(inputs)
    
    # Loss and optimization
    loss = criterion(outputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print updates
    if (epoch + 1) % print_interval == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

print("Training finished!")
# Save the trained model to disk
torch.save(model.state_dict(), "model_weights.pt")


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [10/100], Loss: 0.0042
Epoch [20/100], Loss: 0.0040
Epoch [30/100], Loss: 0.0040
Epoch [40/100], Loss: 0.0039
Epoch [50/100], Loss: 0.0038
Epoch [60/100], Loss: 0.0038
Epoch [70/100], Loss: 0.0038
Epoch [80/100], Loss: 0.0038
Epoch [90/100], Loss: 0.0037
Epoch [100/100], Loss: 0.0037
Training finished!


In [70]:
import matplotlib.pyplot as plt

def plot_images(sequence, predicted, image_size=32):
    ''' Visualize sequences of images and the predicted next frame. '''
    seq_length = sequence.shape[1]
    
    for idx in range(sequence.shape[0]):  # Loop over the batch
        fig, axarr = plt.subplots(1, seq_length+1, figsize=(20, 2))
        
        for t in range(seq_length):
            axarr[t].imshow(sequence[idx, t, 0].cpu().numpy(), cmap='gray', vmin=0, vmax=1)
            axarr[t].axis('off')
            
            if t == 0:
                axarr[t].set_title("Input Sequence")
            
        axarr[seq_length].imshow(predicted[idx, 0].cpu().detach().numpy(), cmap='gray', vmin=0, vmax=1)
        axarr[seq_length].axis('off')
        axarr[seq_length].set_title("Predicted")
        
        plt.show()



# Load the model from disk
model_path = "model_weights.pt"
model = Conv3DNet(input_channels, hidden_channels, kernel_size, sequence_length).to(device)
model.load_state_dict(torch.load(model_path))
model.eval()

# Generate some test sequences
test_batch_size = 5
test_sequences = generate_data(test_batch_size, sequence_length+1, image_size).to(device)
inputs = test_sequences[:, :, :sequence_length, :, :]
targets = test_sequences[:, :, sequence_length, :, :]


# Predict the next frame using the model
with torch.no_grad():
    predictions = model(inputs)

# Plot the sequences and the predicted frames
plot_images(inputs, predictions)


IndexError: index 8 is out of bounds for dimension 1 with size 1