In [1]:
!git clone https://github.com/lisiyao21/AnimeInterp.git
%cd AnimeInterp

Cloning into 'AnimeInterp'...
remote: Enumerating objects: 133, done.[K
remote: Counting objects: 100% (61/61), done.[K
remote: Compressing objects: 100% (32/32), done.[K
remote: Total 133 (delta 43), reused 31 (delta 29), pack-reused 72[K
Receiving objects: 100% (133/133), 2.23 MiB | 5.98 MiB/s, done.
Resolving deltas: 100% (46/46), done.
/content/AnimeInterp


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!cp "/content/drive/MyDrive/atd_12k.zip" "/content/AnimeInterp/datas/"


In [4]:
!unzip "/content/AnimeInterp/datas/atd_12k.zip" -d "/content/AnimeInterp/datas/"


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
   creating: /content/AnimeInterp/datas/datasets/train_10k/Disney_27_22670_s1/
  inflating: /content/AnimeInterp/datas/datasets/train_10k/Disney_27_22670_s1/frame2.jpg  
  inflating: /content/AnimeInterp/datas/datasets/train_10k/Disney_27_22670_s1/frame3.jpg  
  inflating: /content/AnimeInterp/datas/datasets/train_10k/Disney_27_22670_s1/frame1.jpg  
   creating: /content/AnimeInterp/datas/datasets/train_10k/Disney_v2_9_01191_s2/
  inflating: /content/AnimeInterp/datas/datasets/train_10k/Disney_v2_9_01191_s2/frame2.jpg  
  inflating: /content/AnimeInterp/datas/datasets/train_10k/Disney_v2_9_01191_s2/frame3.jpg  
  inflating: /content/AnimeInterp/datas/datasets/train_10k/Disney_v2_9_01191_s2/frame1.jpg  
   creating: /content/AnimeInterp/datas/datasets/train_10k/Disney_v3_6_033220_s2/
  inflating: /content/AnimeInterp/datas/datasets/train_10k/Disney_v3_6_033220_s2/frame2.jpg  
  inflating: /content/AnimeInterp/datas/dataset

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np

class FramePairDataset(Dataset):
    def __init__(self, images_dir, flows_dir, transform=None):
        """
        images_dir: Directory containing frame images.
        flows_dir: Directory containing pre-calculated optical flow data.
        transform: Transformations to apply to the frames.
        """
        self.images_dir = images_dir
        self.flows_dir = flows_dir
        self.transform = transform
        self.samples = self._load_samples()

    def _load_samples(self):
        samples = []
        for folder in sorted(os.listdir(self.images_dir)):
            image_folder_path = os.path.join(self.images_dir, folder)
            flow_folder_path = os.path.join(self.flows_dir, folder)

            if os.path.isdir(image_folder_path) and os.path.isdir(flow_folder_path):
                frames = sorted([f for f in os.listdir(image_folder_path) if f.endswith('.png')])
                flows = sorted([f for f in os.listdir(flow_folder_path) if f.endswith('.npy')])

                if len(frames) >= 3 and len(flows) >= 2:
                    frame1_path = os.path.join(image_folder_path, frames[0])
                    frame3_path = os.path.join(image_folder_path, frames[2])
                    flow1to3_path = os.path.join(flow_folder_path, flows[0])
                    flow3to1_path = os.path.join(flow_folder_path, flows[1])

                    samples.append((frame1_path, frame3_path, flow1to3_path, flow3to1_path))
        return samples

    def __getitem__(self, idx):
        frame1_path, frame3_path, flow1to3_path, flow3to1_path = self.samples[idx]

        frame1 = Image.open(frame1_path).convert('RGB')
        frame3 = Image.open(frame3_path).convert('RGB')
        flow1to3 = np.load(flow1to3_path)
        flow3to1 = np.load(flow3to1_path)

        if self.transform:
            frame1 = self.transform(frame1)
            frame3 = self.transform(frame3)

        # Convert numpy arrays to tensors
        flow1to3 = torch.from_numpy(flow1to3).float()
        flow3to1 = torch.from_numpy(flow3to1).float()

        return frame1, frame3, flow1to3, flow3to1

    def __len__(self):
        return len(self.samples)

# Transformations
transform = transforms.Compose([
    transforms.Resize((540, 960)),  # Adjust size accordingly
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Specify paths
images_dir = '/content/AnimeInterp/datas/datasets/test_2k_540p'
flows_dir = '/content/AnimeInterp/datas/datasets/test_2k_pre_calc_sgm_flows'

# Dataset and DataLoader initialization
dataset = FramePairDataset(images_dir, flows_dir, transform=transform)
loader = DataLoader(dataset, batch_size=4, shuffle=True)


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FlowNetS(nn.Module):
    def __init__(self):
        super(FlowNetS, self).__init__()
        self.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2)
        self.conv3_1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.conv4_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1)
        self.conv6_1 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)

        # Predict flow at different scales
        self.predict_flow6 = nn.Conv2d(1024, 2, kernel_size=3, padding=1)
        self.upsample_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, stride=2, padding=1, output_padding=(1, 1))
        self.upsample_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, stride=2, padding=1, output_padding=(1, 1))

        self.predict_flow5 = nn.Conv2d(514, 2, kernel_size=3, padding=1)  # 512 from conv + 2 from upsampled flow
        self.predict_flow4 = nn.Conv2d(514, 2, kernel_size=3, padding=1)  # Again, 512 from conv + 2 from upsampled flow

    def forward(self, x):
        out_conv1 = F.relu(self.conv1(x))
        out_conv2 = F.relu(self.conv2(out_conv1))
        out_conv3 = F.relu(self.conv3(out_conv2))
        out_conv3_1 = F.relu(self.conv3_1(out_conv3))
        out_conv4 = F.relu(self.conv4(out_conv3_1))
        out_conv4_1 = F.relu(self.conv4_1(out_conv4))
        out_conv5 = F.relu(self.conv5(out_conv4_1))
        out_conv5_1 = F.relu(self.conv5_1(out_conv5))
        out_conv6 = F.relu(self.conv6(out_conv5_1))
        out_conv6_1 = F.relu(self.conv6_1(out_conv6))

        flow6 = self.predict_flow6(out_conv6_1)
        flow6_up = self.upsample_flow6_to_5(flow6)

        # Ensuring dimensions match for concatenation
        if out_conv5_1.size(2) != flow6_up.size(2) or out_conv5_1.size(3) != flow6_up.size(3):
            flow6_up = F.interpolate(flow6_up, size=(out_conv5_1.size(2), out_conv5_1.size(3)), mode='bilinear', align_corners=False)

        flow5 = self.predict_flow5(torch.cat([out_conv5_1, flow6_up], 1))
        flow5_up = self.upsample_flow5_to_4(flow5)

        # Ensuring dimensions match for concatenation
        if out_conv4_1.size(2) != flow5_up.size(2) or out_conv4_1.size(3) != flow5_up.size(3):
            flow5_up = F.interpolate(flow5_up, size=(out_conv4_1.size(2), out_conv4_1.size(3)), mode='bilinear', align_corners=False)

        flow4 = self.predict_flow4(torch.cat([out_conv4_1, flow5_up], 1))

        return flow4


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlowNetS().to(device)
print(model)


FlowNetS(
  (conv1): Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (conv2): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (conv3): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (conv3_1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv4_1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv5_1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv6_1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (predict_flow6): Conv2d(1024, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upsample_flow6_to_5): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), output_pa

In [9]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()
def train(model, dataloader, optimizer, criterion, epochs):
    model.train()
    for epoch in range(epochs):
        print(f'Epoch {epoch+1}:')
        for data in dataloader:
            frame1, frame3, flow1to3, _ = data
            input_tensor = torch.cat([frame1, frame3], dim=1)  # Concatenate frame tensors
            input_tensor = input_tensor.to(device)
            flow1to3 = flow1to3.to(device)

            optimizer.zero_grad()
            predicted_flow = model(input_tensor)

            if predicted_flow.shape != flow1to3.shape:
                predicted_flow = F.interpolate(predicted_flow, size=(flow1to3.size(2), flow1to3.size(3)), mode='bilinear', align_corners=False)

            loss = criterion(predicted_flow, flow1to3)
            loss.backward()
            optimizer.step()

        print(f'Loss: {loss.item()}')

# Run the training
train(model, loader, optimizer, criterion, epochs=15)


Epoch 1:
Loss: 167.64645385742188
Epoch 2:
Loss: 96.15302276611328
Epoch 3:
Loss: 201.96580505371094
Epoch 4:
Loss: 126.0382308959961
Epoch 5:
Loss: 78.23834991455078
Epoch 6:
Loss: 365.24005126953125
Epoch 7:
Loss: 60.572635650634766
Epoch 8:
Loss: 305.84600830078125
Epoch 9:
Loss: 126.52880096435547
Epoch 10:
Loss: 91.4289779663086
Epoch 11:
Loss: 94.09759521484375
Epoch 12:
Loss: 61.66000747680664
Epoch 13:
Loss: 29.03998565673828
Epoch 14:
Loss: 95.7818603515625
Epoch 15:
Loss: 54.086063385009766


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CorrelationLayer(nn.Module):
    """A basic correlation layer assuming a max displacement of 2 and stride of 1."""
    def __init__(self):
        super(CorrelationLayer, self).__init__()

    def forward(self, x1, x2):
        # Example correlation computation with maximum displacement of 2 and stride of 1
        b, c, h, w = x1.size()
        max_disp = 2
        stride = 1
        out_channels = max_disp * 2 + 1
        padded_x2 = F.pad(x2, [max_disp] * 4)
        result = []
        for i in range(-max_disp, max_disp + 1, stride):
            for j in range(-max_disp, max_disp + 1, stride):
                result.append((x1 * padded_x2[:, :, max_disp + i:max_disp + i + h, max_disp + j:max_disp + j + w]).mean(1, keepdim=True))
        return torch.cat(result, 1)

class FlowNetC(nn.Module):
    def __init__(self):
        super(FlowNetC, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2)
        self.correlation = CorrelationLayer()
        self.conv_reduce = nn.Conv2d(49, 256, kernel_size=1)  # Adjust based on correlation output
        self.conv3_1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.conv_final = nn.Conv2d(512, 2, kernel_size=3, padding=1)

    def forward(self, x1, x2):
        out1 = F.relu(self.conv1(x1))
        out1 = F.relu(self.conv2(out1))
        out1 = F.relu(self.conv3(out1))

        out2 = F.relu(self.conv1(x2))
        out2 = F.relu(self.conv2(out2))
        out2 = F.relu(self.conv3(out2))

        out_corr = self.correlation(out1, out2)
        out_corr = F.relu(self.conv_reduce(out_corr))

        out = F.relu(self.conv3_1(out_corr))
        out = F.relu(self.conv4(out))
        out = F.relu(self.conv5(out))
        flow = self.conv_final(out)

        return flow


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlowNetC().to(device)

# Example to show model summary
print(model)


FlowNetC(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (conv2): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (conv3): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (correlation): CorrelationLayer()
  (conv_reduce): Conv2d(49, 256, kernel_size=(1, 1), stride=(1, 1))
  (conv3_1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv5): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv_final): Conv2d(512, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)


In [12]:
def train(model, dataloader, optimizer, criterion, epochs):
    model.train()
    for epoch in range(epochs):
        for img1, img2, flow_true in loader:
            img1, img2, flow_true = img1.to(device), img2.to(device), flow_true.to(device)
            optimizer.zero_grad()
            flow_pred = model(img1, img2)
            loss = criterion(flow_pred, flow_true)
            loss.backward()
            optimizer.step()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_rates = [1e-4, 1e-5]
batch_sizes = [4, 8]

for lr in learning_rates:
    for batch_size in batch_sizes:
        model = FlowNetC().to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.MSELoss()
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        train(model, dataloader, optimizer, criterion, epochs=5)

ValueError: too many values to unpack (expected 3)

In [None]:
def train(model, dataloader, optimizer, criterion, epochs):
    model.train()
    for epoch in range(epochs):
        print(f'Epoch {epoch+1}:')
        for data in dataloader:
            frame1, frame3, flow1to3, _ = data
            input_tensor = torch.cat([frame1, frame3], dim=1)  # Concatenate frame tensors
            input_tensor = input_tensor.to(device)
            flow1to3 = flow1to3.to(device)

            optimizer.zero_grad()
            predicted_flow = model(input_tensor)

            if predicted_flow.shape != flow1to3.shape:
                predicted_flow = F.interpolate(predicted_flow, size=(flow1to3.size(2), flow1to3.size(3)), mode='bilinear', align_corners=False)

            loss = criterion(predicted_flow, flow1to3)
            loss.backward()
            optimizer.step()

            print(f'Loss: {loss.item()}')

# Run the training
train(model, loader, optimizer, criterion, epochs=5)


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision import transforms

def _make_dataset(dir):
    """ Collects pairs of first and third frames from each directory of sequences. """
    frame_pairs = []
    for folder in sorted(os.listdir(dir)):
        folder_path = os.path.join(dir, folder)
        if os.path.isdir(folder_path):
            frames = sorted(os.listdir(folder_path))
            if len(frames) >= 3:
                # Select only the first and third frames
                first_frame_path = os.path.join(folder_path, frames[0])
                third_frame_path = os.path.join(folder_path, frames[2])
                frame_pairs.append((first_frame_path, third_frame_path))
    return frame_pairs

class FramePairDataset(Dataset):
    """ Dataset to load pairs of frames for optical flow computation. """
    def __init__(self, root, transform=None):
        self.frame_pairs = _make_dataset(root)
        self.transform = transform
        if len(self.frame_pairs) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

    def __getitem__(self, index):
        frame1, frame3 = self.frame_pairs[index]
        image1 = Image.open(frame1).convert('RGB')
        image3 = Image.open(frame3).convert('RGB')

        if self.transform:
            image1 = self.transform(image1)
            image3 = self.transform(image3)

        return image1, image3

    def __len__(self):
        return len(self.frame_pairs)


transform = transforms.Compose([
    transforms.Resize((360, 640)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = FramePairDataset('/content/AnimeInterp/datas/datasets/train_10k/', transform=transform)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Iterate through the DataLoader
for image1, image3 in loader:
    print(image1.shape, image3.shape)  # Output the shapes to verify


torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4, 3, 360, 640]) torch.Size([4, 3, 360, 640])
torch.Size([4,

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision import transforms

def _make_dataset(dir):
    """ Collects pairs of first and third frames from each directory of sequences. """
    frame_pairs = []
    for folder in sorted(os.listdir(dir)):
        folder_path = os.path.join(dir, folder)
        if os.path.isdir(folder_path):
            frames = sorted(os.listdir(folder_path))
            if len(frames) >= 3:
                # Select only the first and third frames
                first_frame_path = os.path.join(folder_path, frames[0])
                third_frame_path = os.path.join(folder_path, frames[2])
                frame_pairs.append((first_frame_path, third_frame_path))
    return frame_pairs

class FramePairDataset(Dataset):
    """ Dataset to load pairs of frames for optical flow computation. """
    def __init__(self, root, transform=None):
        self.frame_pairs = _make_dataset(root)
        self.transform = transform
        if len(self.frame_pairs) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

    def __getitem__(self, index):
        frame1, frame3 = self.frame_pairs[index]
        image1 = Image.open(frame1).convert('RGB')
        image3 = Image.open(frame3).convert('RGB')

        if self.transform:
            image1 = self.transform(image1)
            image3 = self.transform(image3)

        return image1, image3

    def __len__(self):
        return len(self.frame_pairs)


transform = transforms.Compose([
    transforms.Resize((360, 640)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = FramePairDataset('/content/AnimeInterp/datas/datasets/train_10k/', transform=transform)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Iterate through the DataLoader
for image1, image3 in loader:
    print(image1.shape, image3.shape)  # Output the shapes to verify

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class FlowNetS(nn.Module):
    def __init__(self):
        super(FlowNetS, self).__init__()
        self.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2)
        self.conv3_1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.conv4_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1)
        self.conv6_1 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)

        # Predict flow at different scales
        self.predict_flow6 = nn.Conv2d(1024, 2, kernel_size=3, padding=1)
        self.upsample_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, stride=2, padding=1)
        self.upsample_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, stride=2, padding=1)

        self.predict_flow5 = nn.Conv2d(1026, 2, kernel_size=3, padding=1) # 1024 from conv + 2 from upsampled flow
        self.predict_flow4 = nn.Conv2d(514, 2, kernel_size=3, padding=1)  # 512 from conv + 2 from upsampled flow

    def forward(self, x):
        out_conv1 = F.relu(self.conv1(x))
        out_conv2 = F.relu(self.conv2(out_conv1))
        out_conv3 = F.relu(self.conv3(out_conv2))
        out_conv3_1 = F.relu(self.conv3_1(out_conv3))
        out_conv4 = F.relu(self.conv4(out_conv3_1))
        out_conv4_1 = F.relu(self.conv4_1(out_conv4))
        out_conv5 = F.relu(self.conv5(out_conv4_1))
        out_conv5_1 = F.relu(self.conv5_1(out_conv5))
        out_conv6 = F.relu(self.conv6(out_conv5_1))
        out_conv6_1 = F.relu(self.conv6_1(out_conv6))

        flow6 = self.predict_flow6(out_conv6_1)
        flow6_up = self.upsample_flow6_to_5(flow6)

        flow5 = self.predict_flow5(torch.cat([out_conv5_1, flow6_up], 1))
        flow5_up = self.upsample_flow5_to_4(flow5)

        flow4 = self.predict_flow4(torch.cat([out_conv4_1, flow5_up], 1))

        return flow4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlowNetS().to(device)
print(model)


In [8]:
import torch.optim as optim

In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()


In [None]:
def train(model, dataloader, optimizer, criterion, epochs):
    model.train()
    for epoch in range(epochs):
        for images1, images3 in dataloader:
            images1, images3 = images1.to(device), images3.to(device)
            # input by concatenating images along the color channel
            input = torch.cat([images1, images3], dim=1)

            optimizer.zero_grad()
            output = model(input)
            #target optical flow which we do not have
            #target_flow =
            # loss = criterion(output, target_flow)


            loss.backward()
            optimizer.step()
            print(f'Epoch {epoch}, Loss {loss.item()}')

train(model, loader, optimizer, criterion, epochs=20)
