In [None]:
#Import PyTorch and matplotlib
import torch
import torchvision
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset
import imageio.v3 as iio
import numpy as np
import matplotlib.pyplot as plt
import time
from pthflops import count_ops

# Decode the whole clip into memory (indexed by frame first, see video[frame] below)
vRead = iio.imread("c_elegans.mp4")
# asarray reuses an existing ndarray instead of duplicating every decoded frame,
# which np.array would do by default
video = np.asarray(vRead)

#Check PyTorch version
torch.__version__


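### Select GPU(s) if available
Run in the shell before starting Jupyter (no space after `=`), where `{gpu#/#s}` is one GPU index or a comma-separated list of indices:
`export CUDA_VISIBLE_DEVICES={gpu#/#s}`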

In [None]:
# Report which CUDA / cuDNN support this PyTorch build can see
cuda_available = torch.cuda.is_available()
print("Cuda available: ", cuda_available)
if cuda_available:
    print("Is cuDNN version:", torch.backends.cudnn.version())
    print("cuDNN enabled: ", torch.backends.cudnn.enabled)
    print("Device count: ", torch.cuda.device_count())

    # Look the active device index up once and reuse it for the name query
    active_device = torch.cuda.current_device()
    print("Current device: ", active_device)
    print("Device name: ", torch.cuda.get_device_name(active_device))

# Device-agnostic setup: prefer the GPU, otherwise fall back to the CPU
device = "cuda" if cuda_available else "cpu"
gpuNum = 1
print(device)

In [None]:
#Video Metadata
import imageio.v3 as iio
props = iio.improps("c_elegans.mp4")
# imageio lays video out as (frames, height, width, channels), so the label
# lists height before width.
print(f"Shape (frames, h, w, RGB): \n{props.shape}")
print(props.dtype)

### Encoding Image as a Tensor

In [None]:
# Uninitialised buffer with the spatial size of one frame
# (props.shape[1] x props.shape[2], presumably height x width -- confirm) and
# 3 colour channels, allocated directly on the target device
grid = torch.empty((*props.shape[1:3], 3), device=device)

# Pull the chosen frame out of the decoded video as a tensor on the same device
frame = 0
image = torch.tensor(video[frame], device=device)

image.shape


In [None]:
# Sanity check: display the untouched source frame before any model is involved
plt.imshow(image.cpu())
plt.axis("off")
plt.title(f"Frame: {frame}");

### Dataset and DataLoader

In [None]:
import os
from torch.utils.data import DataLoader

class SingleImageDataset(Dataset):
    """Expose a single image as a dataset of individual pixels.

    Each sample is ``(row, col, pixel)`` where ``pixel`` is ``image[row][col]``.
    Samples are ordered row by row, i.e. ``idx = row * image.shape[1] + col``.

    Args:
        image: Indexable image whose ``shape[0]`` and ``shape[1]`` are its row
            and column counts.
        transform: Stored on the instance; not applied by this class.
        target_transform: Stored on the instance; not applied by this class.
    """

    def __init__(self, image, transform=None, target_transform=None):
        self.image = image
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of pixels (rows * columns) in the wrapped image."""
        # Read the instance's own image rather than a notebook-level global
        # that merely happens to share the name `image`.
        return int(self.image.shape[0]) * int(self.image.shape[1])

    def __getitem__(self, idx):
        """Return ``(row, col, pixel)`` for the flat row-major index ``idx``."""
        width = int(self.image.shape[1])
        row = idx // width
        col = idx % width
        pixel = self.image[row][col]
        return row, col, pixel
# Wrap the selected frame so that every sample is one (row, col, pixel) triple
training_data = SingleImageDataset(image=image)

# Serve the pixels in shuffled mini-batches of 32
train_dataloader = DataLoader(
    dataset=training_data,
    batch_size=32,
    shuffle=True,
)


In [45]:
#Test Dataloader: rebuild the frame from the shuffled batches
# Size and dtype come from the frame itself (no hard-coded resolution), and the
# buffer starts zeroed so unvisited pixels would show up as black, not garbage.
testGrid = torch.zeros(image.shape, dtype=image.dtype)
for rows, cols, pixels in train_dataloader:
    # rows/cols are the batched pixel coordinates; scatter the pixel values back
    testGrid[rows, cols] = pixels.cpu()

# One pass over the loader must visit every pixel exactly once
assert torch.equal(testGrid, image.cpu()), "DataLoader did not reproduce the frame"

plt.imshow(testGrid.cpu())

### Main Model

In [None]:
import torch
from torch import nn
#from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision import transforms

#Seed PyTorch's global random number generator for reproducibility
torch.manual_seed(42)
#Multilayer Perceptron Model
class MLP(nn.Module):
    """Single-hidden-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_shape: Number of features in each input sample.
        hidden_units: Width of the hidden layer.
        output_shape: Number of values produced per sample.
    """

    def __init__(self,
                 input_shape: int,
                 hidden_units: int,
                 output_shape: int):
        super().__init__()
        self.layer_stack = nn.Sequential(
            nn.Linear(input_shape, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, output_shape),
        )

    #forward reconstruction
    def forward(self, X):
        """Run ``X`` through the layer stack and return the result.

        ``X`` is first moved to the device holding the model's parameters, so
        the module does not depend on a notebook-level ``device`` variable and
        keeps working if the model is moved to another device.
        """
        model_device = next(self.parameters()).device
        return self.layer_stack(X.to(model_device))

### Instance of Model (for selected frame/image)

In [None]:
# Instantiate the network for the selected frame and place it on the chosen device
model_0 = MLP(
    input_shape=3,
    hidden_units=128,
    output_shape=3,
).to(device)

# Mean-squared error as the loss, minimised with Adam
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model_0.parameters(), lr=1e-3)
#list(model_0.parameters())

## Check Size of Model

In [None]:
# Check model size by summing the bytes held in its parameters and buffers
params_size = sum(p.nelement() * p.element_size() for p in model_0.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model_0.buffers())

# Both sizes use the same unit (bytes / 1024**2) so the comparison is meaningful
BYTES_PER_MB = 1024**2
size_all_mb = (params_size + buffer_size) / BYTES_PER_MB

# Raw frame size in bytes: element count times bytes per element.
# (Multiplying the element count by 8 gives bits, which overstates the size.)
imageSize = image.nelement() * image.element_size()
imageSizeMB = imageSize / BYTES_PER_MB

# Relative saving of storing the model instead of the raw frame, in percent
perDecrease = (imageSizeMB - size_all_mb) / imageSizeMB
perDecrease *= 100
print('original image size(no compression): {:.3f}MB'.format(imageSizeMB))
print('model size: {:.3f}MB'.format(size_all_mb))
print('Percent decrease in memory size: {:.3f}%'.format(perDecrease))