# End-to-end (E2E) model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class HParams:
    """Hyperparameters for the end-to-end measurement autoencoders.

    Every field gets the default assigned below. Any field can be overridden
    by keyword, e.g. ``HParams(dataset="ucf", num_measurements=2000)``;
    unknown names raise ``TypeError`` so typos do not silently create new
    attributes. Calling ``HParams()`` with no arguments behaves as before.
    """

    def __init__(self, **overrides):
        self.num_measurements = 1500   # output size of the sensing layer A
        self.layer_sizes = [50, 200]   # widths of the ReLU encoder layers
        self.train_batch_size = 8
        self.test_batch_size = 10
        self.learning_rate = 0.001     # Adam learning rate
        self.max_train_steps = 15000   # original note: 50000
        self.summary_iter = 1000       # print the training loss every N steps
        self.checkpoint_iter = 20000   # NOTE(review): not read by the visible training code
        self.is_A_trainable = False    # whether the sensing matrix A receives gradients
        self.noise_std = 0.1           # std of measurement noise added in eval mode
        self.dataset = "mmnist"        # "mmnist" or "ucf"

        unknown = set(overrides) - set(vars(self))
        if unknown:
            raise TypeError(f"Unknown hyperparameter(s): {sorted(unknown)}")
        for name, value in overrides.items():
            setattr(self, name, value)

    def __repr__(self):
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"
        self.dataset = "mmnist"           # mmnist or ucf

# Active configuration for this run. A trained model is tied to one
# measurement count, so a different num_measurements needs a new model.
hparams = HParams()
hparams.dataset, hparams.num_measurements = "mmnist", 300
    

In [None]:
# Moving MNIST model 
class E2EAutoencoder(nn.Module):
    """End-to-end measurement autoencoder for inputs of 64*64 values per sample.

    A bias-free linear layer ``A`` maps each flattened sample to
    ``hparams.num_measurements`` measurements. An encoder (Linear + ReLU for
    every width in ``hparams.layer_sizes``) and a Linear + Sigmoid decoder
    reconstruct the sample from those measurements.

    Args:
        hparams: object providing ``num_measurements``, ``layer_sizes``,
            ``is_A_trainable`` and ``noise_std``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.input_dim = 64 * 64
        self.A = nn.Linear(self.input_dim, hparams.num_measurements, bias=False)
        # Gaussian sensing matrix with std 1/num_measurements.
        # NOTE(review): 1/sqrt(num_measurements) is the more common scaling;
        # confirm which is intended.
        with torch.no_grad():
            self.A.weight.normal_(std=1.0 / hparams.num_measurements)
        self.A.weight.requires_grad = hparams.is_A_trainable

        layers = []
        prev_size = hparams.num_measurements
        for size in hparams.layer_sizes:
            layers.append(nn.Linear(prev_size, size))
            layers.append(nn.ReLU())
            prev_size = size
        self.encoder = nn.Sequential(*layers)

        self.decoder = nn.Sequential(
            nn.Linear(prev_size, self.input_dim),
            nn.Sigmoid()
        )
        self.noise_std = hparams.noise_std

    def forward(self, x):
        """Reconstruct ``x`` from its (optionally noisy) measurements.

        Args:
            x: tensor with ``64*64`` values per sample along dims 1.. (for
                example ``[N, 1, 64, 64]``); it may be non-contiguous.

        Returns:
            Sigmoid-activated reconstruction with the same shape as ``x``.
        """
        # reshape instead of view so non-contiguous inputs work too: [N, 64*64]
        x_flat = x.reshape(x.size(0), -1)
        y = self.A(x_flat)  # [N, num_measurements]

        if not self.training:
            # Measurement noise is simulated only at evaluation time. The
            # out-of-place add avoids mutating a tensor autograd may track.
            y = y + torch.randn_like(y) * self.noise_std
        hidden = self.encoder(y)
        recon = self.decoder(hidden)  # [N, 64*64]
        return recon.view_as(x)

In [None]:
# UCF model 
import torch
import torch.nn as nn

class E2EAutoencoderUCF(nn.Module):
    """End-to-end measurement autoencoder for inputs of 3*64*64 values per sample.

    Same architecture as the single-channel model: a bias-free linear sensing
    layer ``A`` (``hparams.num_measurements`` outputs), a Linear + ReLU
    encoder over ``hparams.layer_sizes``, and a Linear + Sigmoid decoder.

    Args:
        hparams: object providing ``num_measurements``, ``layer_sizes``,
            ``is_A_trainable`` and ``noise_std``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.input_dim = 3 * 64 * 64
        self.A = nn.Linear(self.input_dim, hparams.num_measurements, bias=False)

        # Unscaled standard-normal sensing matrix.
        # NOTE(review): the 64*64 model scales A by 1/num_measurements. With
        # std 1 over 12288 inputs the measurements are large relative to
        # noise_std; confirm this scaling is intended.
        with torch.no_grad():
            self.A.weight.copy_(torch.randn_like(self.A.weight))
        self.A.weight.requires_grad = hparams.is_A_trainable

        layers = []
        prev_size = hparams.num_measurements
        for size in hparams.layer_sizes:
            layers.append(nn.Linear(prev_size, size))
            layers.append(nn.ReLU())
            prev_size = size
        self.encoder = nn.Sequential(*layers)

        self.decoder = nn.Sequential(
            nn.Linear(prev_size, self.input_dim),
            nn.Sigmoid()
        )

        self.noise_std = hparams.noise_std

    def forward(self, x):
        """Reconstruct ``x`` from its (optionally noisy) measurements.

        Args:
            x: tensor with ``3*64*64`` values per sample along dims 1.. (for
                example ``[N, 3, 64, 64]``); it may be non-contiguous.

        Returns:
            Sigmoid-activated reconstruction with the same shape as ``x``.
        """
        # reshape instead of view so non-contiguous inputs work too: [N, 3*64*64]
        x_flat = x.reshape(x.size(0), -1)
        y = self.A(x_flat)  # [N, num_measurements]

        if not self.training:
            # Measurement noise is simulated only at evaluation time; the
            # add is out-of-place so no autograd-tracked tensor is mutated.
            y = y + torch.randn_like(y) * self.noise_std
        hidden = self.encoder(y)
        recon = self.decoder(hidden)  # [N, 3*64*64]
        return recon.view_as(x)

In [None]:
# Load Moving MNIST data
import torch
from MovingMNISTVideoLoader import MovingMNIST


# Build the train and test splits in that order. start/end are forwarded to
# the project's MovingMNIST loader (their exact meaning is defined there).
train_set, test_set = (
    MovingMNIST(root='data/Movingmnist', start=0, end=10, train=is_train, download=True)
    for is_train in (True, False)
)

# Both loaders shuffle, so each test() call sees a random test batch.
train_loader_mmnist = torch.utils.data.DataLoader(
    dataset=train_set, batch_size=hparams.train_batch_size, shuffle=True
)
test_loader_mmnist = torch.utils.data.DataLoader(
    dataset=test_set, batch_size=hparams.test_batch_size, shuffle=True
)

print(f'==>>> total trainning batch number: {len(train_loader_mmnist)}')
print(f'==>>> total testing batch number: {len(test_loader_mmnist)}')

In [None]:
# Load UCF101 data
import torch
from torch.utils.data import DataLoader, random_split

from src.UCF_handler import UCF101VideoDataset
dataset = UCF101VideoDataset(video_folder="VideoGeneration-PyTorch-main/data/UCF101", transform=None, start=0, end=10)

# 90/10 train/test split. The split is seeded so that it is identical in every
# session: the evaluation cell reloads a checkpoint trained earlier. An
# unseeded split would let that checkpoint's training videos leak into the
# "test" set.
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(0)
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=split_generator
)

# DataLoader: shuffle training batches; keep test order fixed.
train_loader_ucf = DataLoader(train_dataset, batch_size=hparams.train_batch_size, shuffle=True)
test_loader_ucf = DataLoader(test_dataset, batch_size=hparams.test_batch_size, shuffle=False)

In [None]:
# Bind the model and data loaders for the configured dataset. Any value other
# than "mmnist" selects the UCF101 variant.
if hparams.dataset != "mmnist":
    model = E2EAutoencoderUCF(hparams).to(device)
    train_loader, test_loader = train_loader_ucf, test_loader_ucf
else:
    model = E2EAutoencoder(hparams).to(device)
    train_loader, test_loader = train_loader_mmnist, test_loader_mmnist

# Both decoders end in nn.Sigmoid, so their outputs are in [0, 1] as BCELoss requires.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=hparams.learning_rate)


def _cycle(loader):
    """Yield batches from ``loader`` forever, starting a new pass whenever it is exhausted.

    Raises:
        ValueError: if a complete pass over ``loader`` yields no batches
            (this prevents an infinite loop).
    """
    while True:
        empty = True
        for batch in loader:
            empty = False
            yield batch
        if empty:
            raise ValueError("loader produced no batches")


def train():
    """Train the global ``model`` for ``hparams.max_train_steps`` steps, then save it.

    Reads the module-level ``model``, ``train_loader``, ``criterion``,
    ``optimizer``, ``hparams`` and ``device``. Each batch is flattened into
    64x64 frames (1 channel for Moving MNIST, 3 for UCF), rescaled, and
    reconstructed. The loss is printed every ``hparams.summary_iter`` steps.

    The final ``state_dict`` is saved to
    ``e2e_<movingmnist|ucf>_m<num_measurements>model.pth``. Checkpoints for
    different datasets or measurement counts therefore do not overwrite each other.
    """
    is_mmnist = hparams.dataset == "mmnist"
    channels = 1 if is_mmnist else 3
    # A persistent iterator replaces the per-step next(iter(train_loader)).
    # That call rebuilt the iterator, and reshuffled all indices, on every step.
    batches = _cycle(train_loader)

    model.train()
    for step in range(hparams.max_train_steps):
        batch = next(batches)
        if is_mmnist:
            data, _ = batch  # Moving MNIST batches are 2-element pairs
        else:
            data = batch
        # Presumably [B, T, C, H, W] -> [B*T, C, 64, 64] (layout assumed; confirm
        # against the loaders). reshape also copes with non-contiguous batches.
        data = data.reshape(-1, channels, 64, 64).to(device)

        # Loader values are assumed to lie in [-1, 1]. Map them to [0, 1] so
        # they are valid BCE targets for the sigmoid output.
        data = (data + 1) / 2

        optimizer.zero_grad()
        recon = model(data)
        loss = criterion(recon, data)
        loss.backward()
        optimizer.step()

        if step % hparams.summary_iter == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}")

    dataset_tag = "movingmnist" if is_mmnist else "ucf"
    torch.save(model.state_dict(), f"e2e_{dataset_tag}_m{hparams.num_measurements}model.pth")


train()

In [None]:
import matplotlib.pyplot as plt

# Evaluate: rebuild the model for the active configuration and load its checkpoint.
# Checkpoints are named e2e_<movingmnist|ucf>_m<num_measurements>model.pth
# (e.g. e2e_ucf_m2000model.pth for UCF with 2000 measurements). The model
# class must match the one used for training, and so must num_measurements;
# otherwise load_state_dict fails on a size mismatch.
if hparams.dataset == "mmnist":
    model = E2EAutoencoder(hparams).to(device)
    ckpt_path = f"e2e_movingmnist_m{hparams.num_measurements}model.pth"
else:
    model = E2EAutoencoderUCF(hparams).to(device)
    ckpt_path = f"e2e_ucf_m{hparams.num_measurements}model.pth"
model.load_state_dict(torch.load(ckpt_path, map_location=device))


def test():
    """Evaluate the global ``model`` on one test batch and plot the results.

    Prints the batch shapes and the test loss (``criterion``). Then shows up to
    10 frames: originals on the top row, reconstructions on the bottom row.
    """
    is_mmnist = hparams.dataset == "mmnist"
    channels = 1 if is_mmnist else 3

    model.eval()
    with torch.no_grad():
        batch = next(iter(test_loader))
        if is_mmnist:
            data, _ = batch  # Moving MNIST batches are 2-element pairs
        else:
            data = batch

        data = (data + 1) / 2  # assumed [-1, 1] -> [0, 1], as in training
        print(data.shape)
        data = data.reshape(-1, channels, 64, 64).to(device)
        print(data.shape)

        recon = model(data)
        test_loss = criterion(recon, data)
        print(f"Test Loss: {test_loss.item():.4f}")

    # One grid column per frame shown; never index past the grid.
    n_show = min(10, data.size(0))
    fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)
    for row, images in enumerate((data, recon)):
        for i in range(n_show):
            img = images[i].cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
            ax = axes[row, i]
            if img.shape[2] == 1:
                # Drop the singleton channel and render single-channel frames in grayscale.
                ax.imshow(img[:, :, 0], cmap="gray")
            else:
                ax.imshow(img)
            ax.axis("off")
    plt.show()


test()