In [12]:
import numpy as np

def generate_sequence(img_size=64, seq_length=10):
    # Initial position
    x, y = np.random.randint(0, img_size, 2)
    
    # Random direction (dx, dy) can be -1, 0, 1
    dx, dy = np.random.choice([-1, 0, 1], 2)
    
    # Ensure it's not stationary
    while (dx, dy) == (0, 0):
        dx, dy = np.random.choice([-1, 0, 1], 2)
    
    sequence = []
    for _ in range(seq_length):
        img = np.zeros((img_size, img_size))
        img[x, y] = 1
        sequence.append(img)
        
        x = (x + dx) % img_size  # Wrap-around on the x-axis
        y = (y + dy) % img_size  # Wrap-around on the y-axis
    
    return np.array(sequence)


In [13]:
class CircularPad3d(nn.Module):
    def __init__(self, padding):
        super(CircularPad3d, self).__init__()
        self.padding = padding
        
    def forward(self, x):
        # Padding: (left, right, top, bottom, front, back)
        left, right, top, bottom, front, back = self.padding
        x = torch.cat((x[:, :, :, :, -front:], x, x[:, :, :, :, :back]), dim=4)
        x = torch.cat((x[:, :, :, -top:], x, x[:, :, :, :bottom]), dim=3)
        x = torch.cat((x[:, :, -left:], x, x[:, :, :right]), dim=2)
        return x

# Adjust the MotionPredictor model:
class MotionPredictor(nn.Module):
    def __init__(self):
        super(MotionPredictor, self).__init__()
        self.pad = CircularPad3d((1, 1, 1, 1, 1, 1))
        self.conv1 = nn.Conv3d(1, 16, kernel_size=(3, 3, 3), stride=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=(3, 3, 3), stride=1)
        self.fc1 = nn.Linear(32 * 60 * 60, 64 * 64)

        
    def forward(self, x):
        x = self.pad(x)
        x = self.conv1(x)
        x = nn.ReLU()(x)
        
        x = self.pad(x)
        x = self.conv2(x)
        x = nn.ReLU()(x)
        
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x.view(x.size(0), 64, 64)


In [14]:
import os

num_sequences = 100
sequence_length = 10
data_dir = 'training_data/'

# Create directory if it doesn't exist
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

for i in range(num_sequences):
    seq = generate_sequence(seq_length=sequence_length)
    np.save(os.path.join(data_dir, f'sequence_{i}.npy'), seq)


In [15]:
import torch
import torch.optim as optim

sequences = []

for i in range(num_sequences):
    seq = np.load(os.path.join(data_dir, f'sequence_{i}.npy'))
    sequences.append(seq)

sequences = np.array(sequences)


# Hyperparameters
learning_rate = 0.001
epochs = 1000
batch_size = 32

# Initialize model and optimizer
model = MotionPredictor()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    epoch_loss = 0.0
    for i in range(0, num_sequences, batch_size):
        # Get batch from sequences
        seqs = sequences[i: i + batch_size]
        input_seqs = torch.tensor(seqs[:, :-1], dtype=torch.float32).unsqueeze(1)
        target_seqs = torch.tensor(seqs[:, 1:], dtype=torch.float32).unsqueeze(1)
        
        optimizer.zero_grad()
        outputs = model(input_seqs)
        loss = criterion(outputs, target_seqs)
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/(num_sequences/batch_size)}")

torch.save(model.state_dict(), "motion_predictor_model.pth")


RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x1179648 and 131072x4096)

In [None]:
import matplotlib.pyplot as plt

# Load the model
model = MotionPredictor()
model.load_state_dict(torch.load("motion_predictor_model.pth"))
model.eval()

# Generate a sequence for validation
seq = generate_sequence(seq_length=sequence_length)
input_seq = torch.tensor(seq[:-1], dtype=torch.float32).unsqueeze(0).unsqueeze(1)
predicted_seq = model(input_seq)

# Visualizing
fig, axs = plt.subplots(2, sequence_length, figsize=(20, 4))
for i in range(sequence_length):
    if i != sequence_length - 1:
        axs[0, i].imshow(seq[i], cmap='gray')
        axs[1, i].imshow(predicted_seq[0, i].detach().numpy(), cmap='gray')
    else:
        axs[0, i].imshow(seq[i], cmap='gray')
        axs[1, i].set_title('Predicted Next Frame')
        axs[1, i].imshow(predicted_seq[0, -1].detach().numpy(), cmap='gray')
    
    axs[0, i].axis('off')
    axs[1, i].axis('off')

plt.tight_layout()
plt.show()
