In [1]:
pwd

'/media/anil/hdd2/nihal/gmflow'

In [18]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # Use both GPUs

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import argparse
import numpy as np
from torch.nn import init
from tqdm.notebook import tqdm

from data import build_train_dataset
from gmflow.gmflow import GMFlow
from loss import flow_loss_func
from evaluate import validate_chairs, validate_things, validate_sintel, validate_kitti
from utils.logger import Logger
from utils import misc

from gmflow.backbone import CNNEncoder
from gmflow.transformer import FeatureTransformer, FeatureFlowAttention
from gmflow.matching import global_correlation_softmax, local_correlation_softmax
from gmflow.geometry import flow_warp
from gmflow.utils import normalize_img, feature_add_position
from gmflow.trident_conv import MultiScaleTridentConv

In [19]:
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, dilation=1):
        super(DepthwiseSeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(
            in_channels, 
            in_channels, 
            kernel_size=3, 
            stride=stride, 
            padding=dilation,
            groups=in_channels,
            dilation=dilation,
            bias=False
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        
    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

class SqueezeExcitation(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SqueezeExcitation, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        scale = self.avg_pool(x)
        scale = self.fc1(scale)
        scale = self.relu(scale)
        scale = self.fc2(scale)
        scale = self.sigmoid(scale)
        return x * scale

class ResidualBlockWithSE(nn.Module):
    def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1):
        super(ResidualBlockWithSE, self).__init__()
        
        self.conv1 = DepthwiseSeparableConv(in_planes, planes, stride=stride, dilation=dilation)
        self.conv2 = DepthwiseSeparableConv(planes, planes, dilation=dilation)
        
        self.norm1 = norm_layer(planes)
        self.norm2 = norm_layer(planes)
        
        self.se = SqueezeExcitation(planes)
        self.relu = nn.ReLU(inplace=True)

        if not stride == 1 or in_planes != planes:
            self.norm3 = norm_layer(planes)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3
            )
        else:
            self.downsample = None

    def forward(self, x):
        identity = x
        
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.norm2(out)
        
        out = self.se(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)
        
        out += identity
        out = self.relu(out)
        
        return out

class EnhancedCNNEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_layer=nn.InstanceNorm2d, num_output_scales=1):
        super(EnhancedCNNEncoder, self).__init__()
        self.num_branch = num_output_scales
        
        feature_dims = [64, 96, 128]
        
        # Initial convolution with same parameters as original
        self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1 = norm_layer(feature_dims[0])
        self.relu1 = nn.ReLU(inplace=True)
        
        # Create residual blocks
        self.in_planes = feature_dims[0]
        self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer)
        self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer)
        
        # Maintain original stride logic
        stride = 2 if num_output_scales == 1 else 1
        self.layer3 = self._make_layer(feature_dims[2], stride=stride, norm_layer=norm_layer)
        
        # Final 1x1 conv to match output dimension
        self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
        
        # Initialize Trident Conv if multiple scales needed
        if self.num_branch > 1:
            if self.num_branch == 4:
                strides = (1, 2, 4, 8)
            elif self.num_branch == 3:
                strides = (1, 2, 4)
            elif self.num_branch == 2:
                strides = (1, 2)
            else:
                raise ValueError
                
            self.trident_conv = MultiScaleTridentConv(
                output_dim, 
                output_dim,
                kernel_size=3,
                strides=strides,
                paddings=1,
                num_branch=self.num_branch
            )
        
        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
        layer1 = ResidualBlockWithSE(self.in_planes, dim, norm_layer=norm_layer, 
                                   stride=stride, dilation=dilation)
        layer2 = ResidualBlockWithSE(dim, dim, norm_layer=norm_layer, 
                                   stride=1, dilation=dilation)
        
        layers = (layer1, layer2)
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        
        x = self.layer1(x)  # 1/2
        x = self.layer2(x)  # 1/4
        x = self.layer3(x)  # 1/8 or 1/4 depending on num_output_scales
        
        x = self.conv2(x)
        
        if self.num_branch > 1:
            out = self.trident_conv([x] * self.num_branch)
        else:
            out = [x]
            
        return out

In [20]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class DilatedConv(nn.Module):
#     def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
#         super().__init__()
#         padding = dilation * (kernel_size - 1) // 2
#         self.conv = nn.Conv2d(
#             in_channels, 
#             out_channels,
#             kernel_size=kernel_size,
#             stride=stride,
#             padding=padding,
#             dilation=dilation,
#             bias=False
#         )
    
#     def forward(self, x):
#         return self.conv(x)

# class ASPP(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super().__init__()
#         dilations = [1, 6, 12, 18]
        
#         # Keep channels same as input throughout ASPP
#         self.aspp = nn.ModuleList()
#         for dilation in dilations:
#             self.aspp.append(
#                 nn.Sequential(
#                     nn.Conv2d(in_channels, in_channels, 1, bias=False),
#                     nn.InstanceNorm2d(in_channels),
#                     nn.ReLU(inplace=True),
#                     DilatedConv(in_channels, in_channels, dilation=dilation),
#                     nn.InstanceNorm2d(in_channels),
#                     nn.ReLU(inplace=True)
#                 )
#             )
        
#         self.global_branch = nn.Sequential(
#             nn.AdaptiveAvgPool2d(1),
#             nn.Conv2d(in_channels, in_channels, 1, bias=False),
#             nn.ReLU(inplace=True)
#         )
        
#         # Ensure output channels match input
#         self.output_conv = nn.Sequential(
#             nn.Conv2d(in_channels * 5, out_channels, 1, bias=False),
#             nn.InstanceNorm2d(out_channels),
#             nn.ReLU(inplace=True)
#         )
    
#     def forward(self, x):
#         size = x.size()[2:]
        
#         res = []
#         for aspp_module in self.aspp:
#             res.append(aspp_module(x))
        
#         global_context = self.global_branch(x)
#         global_context = F.interpolate(
#             global_context,
#             size=size,
#             mode='bilinear',
#             align_corners=False
#         )
        
#         res.append(global_context)
#         combined = torch.cat(res, dim=1)
        
#         return self.output_conv(combined)

# class EnhancedResidualBlock(nn.Module):
#     def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1):
#         super().__init__()
        
#         # Keep the same structure as original ResidualBlock
#         self.conv1 = DilatedConv(in_planes, planes, stride=stride, dilation=dilation)
#         self.conv2 = DilatedConv(planes, planes, dilation=dilation)
        
#         self.norm1 = norm_layer(planes)
#         self.norm2 = norm_layer(planes)
        
#         self.relu = nn.ReLU(inplace=True)
        
#         # Match original downsample exactly
#         if stride != 1 or in_planes != planes:
#             self.norm3 = norm_layer(planes)
#             self.downsample = nn.Sequential(
#                 nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
#                 self.norm3
#             )
#         else:
#             self.downsample = None

#     def forward(self, x):
#         identity = x
        
#         out = self.conv1(x)
#         out = self.norm1(out)
#         out = self.relu(out)
        
#         out = self.conv2(out)
#         out = self.norm2(out)
        
#         if self.downsample is not None:
#             identity = self.downsample(x)
        
#         out += identity
#         out = self.relu(out)
        
#         return out

# class EnhancedCNNEncoder(nn.Module):
#     def __init__(self, output_dim=128, norm_layer=nn.InstanceNorm2d, num_output_scales=1):
#         super().__init__()
#         self.num_branch = num_output_scales
        
#         feature_dims = [64, 96, 128]
        
#         # Keep original first conv
#         self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False)
#         self.norm1 = norm_layer(feature_dims[0])
#         self.relu1 = nn.ReLU(inplace=True)
        
#         # Main layers
#         self.in_planes = feature_dims[0]
#         self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer)
#         self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer)
        
#         # Add ASPP between layer2 and layer3
#         self.aspp = ASPP(feature_dims[1], feature_dims[2])  # output: 128 channels
        
#         # IMPORTANT CHANGE: Update in_planes to match ASPP output
#         self.in_planes = feature_dims[2]  # Now 128 channels
        
#         # layer3 receives output from ASPP (128 channels)
#         stride = 2 if num_output_scales == 1 else 1
#         self.layer3 = self._make_layer(feature_dims[2], stride=stride, norm_layer=norm_layer)
        
#         # Final projection
#         self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
        
#         if self.num_branch > 1:
#             if self.num_branch == 4:
#                 strides = (1, 2, 4, 8)
#             elif self.num_branch == 3:
#                 strides = (1, 2, 4)
#             elif self.num_branch == 2:
#                 strides = (1, 2)
#             else:
#                 raise ValueError
            
#             self.trident_conv = MultiScaleTridentConv(
#                 output_dim, 
#                 output_dim,
#                 kernel_size=3,
#                 strides=strides,
#                 paddings=1,
#                 num_branch=self.num_branch
#             )
        
#         self._initialize_weights()

        
#     def _initialize_weights(self):
#         for m in self.modules():
#             if isinstance(m, nn.Conv2d):
#                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
#             elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
#                 if m.weight is not None:
#                     nn.init.constant_(m.weight, 1)
#                 if m.bias is not None:
#                     nn.init.constant_(m.bias, 0)

#     def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
#         # Create layer with current in_planes (important for proper channel handling)
#         layer1 = EnhancedResidualBlock(self.in_planes, dim, norm_layer=norm_layer, 
#                                      stride=stride, dilation=dilation)
#         layer2 = EnhancedResidualBlock(dim, dim, norm_layer=norm_layer, 
#                                      stride=1, dilation=dilation)
        
#         self.in_planes = dim
#         return nn.Sequential(layer1, layer2)

#     def forward(self, x):
#         # Add dimension prints to debug
#         print(f"Initial input: {x.shape}")
        
#         x = self.conv1(x)
#         x = self.norm1(x)
#         x = self.relu1(x)
#         print(f"After conv1: {x.shape}")  # Should be [B, 64, H/2, W/2]
        
#         x = self.layer1(x)
#         print(f"After layer1: {x.shape}")  # Should be [B, 64, H/2, W/2]
        
#         x = self.layer2(x)
#         print(f"After layer2: {x.shape}")  # Should be [B, 96, H/4, W/4]
        
#         x = self.aspp(x)
#         print(f"After ASPP: {x.shape}")    # Should be [B, 128, H/4, W/4]
        
#         x = self.layer3(x)
#         print(f"After layer3: {x.shape}")  # Should be [B, 128, H/8 or H/4, W/8 or W/4]
        
#         x = self.conv2(x)
#         print(f"After conv2: {x.shape}")   # Should be [B, output_dim, H/8 or H/4, W/8 or W/4]
        
#         if self.num_branch > 1:
#             out = self.trident_conv([x] * self.num_branch)
#         else:
#             out = [x]
        
#         return out

In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

class SlotAttentionFlow(nn.Module):
    def __init__(self, in_channels, iters=3, eps=1e-8):
        super(SlotAttentionFlow, self).__init__()
        
        self.num_slots = 2
        self.iters = iters
        self.eps = eps
        self.scale = in_channels ** -0.5
        self.dim = in_channels
        
        # Feature projections
        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)
        self.v_proj = nn.Linear(in_channels, in_channels)
        
        # Slot processing
        self.slot_init = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.LayerNorm(in_channels),
            nn.ReLU(inplace=True)
        )
        
        # GRU update
        self.gru = nn.GRUCell(in_channels, in_channels)
        
        # Flow refinement - make it resolution-independent
        self.flow_mlp = nn.Sequential(
            nn.Linear(in_channels, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 2)  # Output 2D flow per position
        )
        
        # Norms
        self.norm_feat = nn.LayerNorm(in_channels)
        self.norm_slots = nn.LayerNorm(in_channels)
        
        self._reset_parameters()
        
    def _reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)
    
    def forward(self, feature0, flow, local_window_attn=False, local_window_radius=1):
        print(f"\nForward input shapes:")
        print(f"feature0: {feature0.shape}")
        print(f"flow: {flow.shape}")
        print(f"local_window_attn: {local_window_attn}")
        
        if local_window_attn:
            return self.forward_local_window_attn(feature0, flow, local_window_radius)
            
        b, c, h, w = feature0.size()
        print(f"\nProcessing global attention with dims b={b}, c={c}, h={h}, w={w}")
        
        # Reshape feature for attention
        feature = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
        feature = self.norm_feat(feature)
        print(f"Normalized feature shape: {feature.shape}")
        
        # Initialize slots from feature
        slots = feature.mean(dim=1, keepdim=True)  # [B, 1, C]
        slots = self.slot_init(slots)
        slots = slots.repeat(1, self.num_slots, 1)  # [B, 2, C]
        print(f"Initial slots shape: {slots.shape}")
        
        # Prepare KV
        k = self.k_proj(feature)  # [B, H*W, C]
        v = self.v_proj(feature)  # [B, H*W, C]
        print(f"K shape: {k.shape}, V shape: {v.shape}")
        
        # Iterative refinement
        for iter_idx in range(self.iters):
            print(f"\nIteration {iter_idx + 1}")
            slots_prev = slots
            
            q = self.q_proj(self.norm_slots(slots))  # [B, 2, C]
            print(f"Query shape: {q.shape}")
            
            attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # [B, 2, H*W]
            attn = F.softmax(attn, dim=-1)
            print(f"Attention shape: {attn.shape}")
            
            updates = torch.matmul(attn, v)  # [B, 2, C]
            print(f"Updates shape: {updates.shape}")
            
            slots = self.gru(
                updates.reshape(-1, self.dim),
                slots_prev.reshape(-1, self.dim)
            )
            slots = slots.reshape(b, self.num_slots, self.dim)
            print(f"Updated slots shape: {slots.shape}")
        
        # Process each position with the slot features
        feature_with_slots = feature + torch.matmul(attn.transpose(-2, -1), slots)  # [B, H*W, C]
        flow_updates = self.flow_mlp(feature_with_slots)  # [B, H*W, 2]
        flow_updates = flow_updates.permute(0, 2, 1).view(b, 2, h, w)  # [B, 2, H, W]
        
        print(f"\nFlow updates shape: {flow_updates.shape}")
        print(f"Input flow shape: {flow.shape}")
        
        refined_flow = flow + flow_updates
        print(f"Refined flow shape: {refined_flow.shape}")
        
        return refined_flow
    
    def forward_local_window_attn(self, feature0, flow, local_window_radius):
        b, c, h, w = feature0.size()
        print(f"\nLocal window attention - input dims: b={b}, c={c}, h={h}, w={w}")
        
        kernel_size = 2 * local_window_radius + 1
        print(f"Kernel size: {kernel_size}")
        
        # Process features
        feature = feature0.view(b, c, -1).permute(0, 2, 1).contiguous()  # [B, H*W, C]
        feature = self.norm_feat(feature)
        print(f"Normalized feature shape: {feature.shape}")
        
        # Initialize slots
        slots = feature.mean(dim=1, keepdim=True)  # [B, 1, C]
        slots = self.slot_init(slots)
        slots = slots.repeat(1, self.num_slots, 1)  # [B, 2, C]
        print(f"Initial slots shape: {slots.shape}")
        
        # Process local windows
        for iter_idx in range(self.iters):
            print(f"\nLocal iteration {iter_idx + 1}")
            
            q = self.q_proj(self.norm_slots(slots))  # [B, 2, C]
            k = self.k_proj(feature)  # [B, H*W, C]
            print(f"Query shape: {q.shape}, Key shape: {k.shape}")
            
            attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # [B, 2, H*W]
            attn = F.softmax(attn, dim=-1)
            print(f"Attention shape: {attn.shape}")
            
            updates = torch.matmul(attn, k)  # [B, 2, C]
            print(f"Updates shape: {updates.shape}")
            
            slots = self.gru(
                updates.reshape(-1, self.dim),
                slots.reshape(-1, self.dim)
            )
            slots = slots.reshape(b, self.num_slots, self.dim)
            print(f"Updated slots shape: {slots.shape}")
        
        # Process each position with the slot features
        feature_with_slots = feature + torch.matmul(attn.transpose(-2, -1), slots)  # [B, H*W, C]
        flow_updates = self.flow_mlp(feature_with_slots)  # [B, H*W, 2]
        flow_updates = flow_updates.permute(0, 2, 1).view(b, 2, h, w)  # [B, 2, H, W]
        
        print(f"\nFlow updates shape: {flow_updates.shape}")
        
        refined_flow = flow + flow_updates
        print(f"Final refined flow shape: {refined_flow.shape}")
        
        return refined_flow

In [56]:
class GMFlow(nn.Module):
    def __init__(self,
                 num_scales=1,
                 upsample_factor=8,
                 feature_channels=128,
                 attention_type='swin',
                 num_transformer_layers=6,
                 ffn_dim_expansion=4,
                 num_head=1,
                 **kwargs,
                 ):
        super(GMFlow, self).__init__()

        self.num_scales = num_scales
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.attention_type = attention_type
        self.num_transformer_layers = num_transformer_layers

        # CNN backbone
#         self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)
        self.backbone = EnhancedCNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)
        # Transformer
        self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
                                              d_model=feature_channels,
                                              nhead=num_head,
                                              attention_type=attention_type,
                                              ffn_dim_expansion=ffn_dim_expansion,
                                              )

        # flow propagation with self-attn
#         self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels)
        self.feature_flow_attn = SlotAttentionFlow(
            in_channels=feature_channels,
            iters=3  # adjust based on your needs
        )       
        # convex upsampling: concat feature0 and flow as input
        self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
                                       nn.ReLU(inplace=True),
                                       nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0))

    def extract_feature(self, img0, img1):
        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
        features = self.backbone(concat)  # list of [2B, C, H, W], resolution from high to low

        # reverse: resolution from low to high
        features = features[::-1]

        feature0, feature1 = [], []

        for i in range(len(features)):
            feature = features[i]
            chunks = torch.chunk(feature, 2, 0)  # tuple
            feature0.append(chunks[0])
            feature1.append(chunks[1])

        return feature0, feature1

    def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
                      ):
        if bilinear:
            up_flow = F.interpolate(flow, scale_factor=upsample_factor,
                                    mode='bilinear', align_corners=True) * upsample_factor

        else:
            # convex upsampling
            concat = torch.cat((flow, feature), dim=1)

            mask = self.upsampler(concat)
            b, flow_channel, h, w = flow.shape
            mask = mask.view(b, 1, 9, self.upsample_factor, self.upsample_factor, h, w)  # [B, 1, 9, K, K, H, W]
            mask = torch.softmax(mask, dim=2)

            up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1)
            up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w)  # [B, 2, 9, 1, 1, H, W]

            up_flow = torch.sum(mask * up_flow, dim=2)  # [B, 2, K, K, H, W]
            up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, K, H, K, W]
            up_flow = up_flow.reshape(b, flow_channel, self.upsample_factor * h,
                                      self.upsample_factor * w)  # [B, 2, K*H, K*W]

        return up_flow

    def forward(self, img0, img1,
                attn_splits_list=None,
                corr_radius_list=None,
                prop_radius_list=None,
                pred_bidir_flow=False,
                **kwargs,
                ):

        results_dict = {}
        flow_preds = []

        img0, img1 = normalize_img(img0, img1)  # [B, 3, H, W]

        # resolution low to high
        feature0_list, feature1_list = self.extract_feature(img0, img1)  # list of features

        flow = None

        assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales

        for scale_idx in range(self.num_scales):
            feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

            if pred_bidir_flow and scale_idx > 0:
                # predicting bidirectional flow with refinement
                feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)

            upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx))

            if scale_idx > 0:
                flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2
            print(f"shape of feature1: {feature1.shape}")
            if flow is not None:
                flow = flow.detach()
                feature1 = flow_warp(feature1, flow)  # [B, C, H, W]

            attn_splits = attn_splits_list[scale_idx]
            corr_radius = corr_radius_list[scale_idx]
            prop_radius = prop_radius_list[scale_idx]

            # add position to features
            feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)

            # Transformer
            feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits)

            # correlation and softmax
            if corr_radius == -1:  # global matching
                flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0]
            else:  # local matching
                flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0]

            # flow or residual flow
            flow = flow + flow_pred if flow is not None else flow_pred

            # upsample to the original resolution for supervison
            if self.training:  # only need to upsample intermediate flow predictions at training time
                flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor)
                flow_preds.append(flow_bilinear)

            # flow propagation with self-attn
            if pred_bidir_flow and scale_idx == 0:
                feature0 = torch.cat((feature0, feature1), dim=0)  # [2*B, C, H, W] for propagation
            flow = self.feature_flow_attn(feature0, flow.detach(),
                                          local_window_attn=prop_radius > 0,
                                          local_window_radius=prop_radius)

            # bilinear upsampling at training time except the last one
            if self.training and scale_idx < self.num_scales - 1:
                flow_up = self.upsample_flow(flow, feature0, bilinear=True, upsample_factor=upsample_factor)
                flow_preds.append(flow_up)

            if scale_idx == self.num_scales - 1:
                flow_up = self.upsample_flow(flow, feature0)
                flow_preds.append(flow_up)

        results_dict.update({'flow_preds': flow_preds})

        return results_dict


In [57]:
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_dir', default='checkpoints/gmflow_with_refine', type=str)
    parser.add_argument('--stage', default='chairs', type=str)
    parser.add_argument('--image_size', default=[384, 512], type=int, nargs='+')
    parser.add_argument('--padding_factor', default=32, type=int)
    parser.add_argument('--num_scales', default=2, type=int)
    parser.add_argument('--attn_splits_list', default=[2, 8], type=int, nargs='+')
    parser.add_argument('--corr_radius_list', default=[-1, 4], type=int, nargs='+')
    parser.add_argument('--prop_radius_list', default=[-1, 1], type=int, nargs='+')
    parser.add_argument('--num_steps', default=100000, type=int)
    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--lr', default=4e-4, type=float)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--gamma', default=0.9, type=float)
#     parser.add_argument('--image_size', default=[384, 512], type=int, nargs='+')
    # Add other necessary arguments from the original script
    args = parser.parse_args([])  # Parse empty list to use defaults
    return args

args = get_args()

In [58]:
model = GMFlow(num_scales=args.num_scales,
               feature_channels=128,
               upsample_factor=4,
               num_head=1,
               attention_type='swin',
               ffn_dim_expansion=4,
               num_transformer_layers=6)

model = torch.nn.DataParallel(model).cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps + 10,
                                                pct_start=0.05, cycle_momentum=False, anneal_strategy='cos')

# print(model)

In [59]:
def count_parameters(model):
 
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    return sum(p.numel() for p in model.parameters())

total_params = count_parameters(model)
print(f"Total Parameters: {total_params}") #3697524

Total Parameters: 3999382


In [60]:
train_dataset = build_train_dataset(args)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)

In [61]:
summary_writer = SummaryWriter(args.checkpoint_dir)
logger = Logger(scheduler, summary_writer, summary_freq=100)

In [62]:
def train_epoch(epoch):
    model.train()
    for i, sample in enumerate(tqdm(train_loader, desc=f'Epoch {epoch}', unit='batch')):
        img1, img2, flow_gt, valid = [x.cuda() for x in sample]
        
        results_dict = model(img1, img2,
                             attn_splits_list=args.attn_splits_list,
                             corr_radius_list=args.corr_radius_list,
                             prop_radius_list=args.prop_radius_list)
        
        flow_preds = results_dict['flow_preds']
        
        loss, metrics = flow_loss_func(flow_preds, flow_gt, valid, gamma=args.gamma)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        logger.push(metrics)
        logger.add_image_summary(img1, img2, flow_preds, flow_gt)
        
        if i % 100 == 0:
            tqdm.write(f"Epoch {epoch}, Step {i}, Loss: {loss.item()}")

In [None]:
num_epochs = args.num_steps // len(train_loader) + 1
for epoch in range(num_epochs):
    train_epoch(epoch)
    
    # Validation
    if epoch % 10 == 0:  # Adjust validation frequency as needed
        results = {}
        results.update(validate_chairs(model))
        results.update(validate_things(model))
        results.update(validate_sintel(model))
        results.update(validate_kitti(model))
        logger.write_dict(results)

    # Save checkpoint
    if epoch % 10 == 0:  # Adjust saving frequency as needed
        torch.save({
            'model': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }, f'{args.checkpoint_dir}/checkpoint_epoch_{epoch}.pth')

print("Training completed!")

Epoch 0:   0%|          | 0/11116 [00:00<?, ?batch/s]

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Attention shape: torch.Size([1, 2, 3072])
Normalized feature shape: torch.Size([1, 3072, 128])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Initial slots shape: torch.Size([1, 2, 128])
Query shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape:

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 3
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Flow updates shape: to

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Updates shape: torch.Size([1, 2, 128])
Normalized feature shape: torch.Size([1, 3072, 128])
Updated s

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Query shape: torch.Size([1, 2, 128])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Attention shape: torch.Size([1, 2, 3072])
Normalized feature shape: torch.Size([1, 3072, 128])
Updates shape: torch.Size([1, 2, 128])
Updated s

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Attention shape: torch.Size([1, 2, 3072])
Updates shape:

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Normalized feature shape: torch.Size([1, 3072, 128])
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Initial slots shape: torch.Size([1, 2, 128])
Updates shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Updated slots shape: torch.Size([1, 2, 128])

Iteratio

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Initial slots shape: torch.Size([1, 2, 128])
Normalized feature shape: torch.Size([1, 3072, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Initial slots shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Updates shape: torch.Size([1, 2, 128])
Query shape: torch.Size([1, 2, 128])
Updated slots sha

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Normalized feature shape: torch.Size([1, 3072, 128])
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Initial slots shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteratio

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Query shape: torch.Size([1, 2, 128])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Attention shape: torch.Size([1, 2, 3072])
Normalized feature shape: torch.Size([1, 3072, 128])
Updates shape: torch.Size([1, 2, 128])
Updated s

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 3
Norm

shape of feature1: torch.Size([1, 128, 48, 64])shape of feature1: torch.Size([1, 128, 48, 64])


Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 3
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Flow updates shape: to

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Initial slots shape: torch.Size([1, 2, 128])
Normalized feature shape: torch.Size([1, 3072, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Initial slots shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Updates shape: torch.Size([1, 2, 128])
Query shape: torch.Size([1, 2, 128])
Updated slots sha

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Normalized feature shape: torch.Size([1, 3072, 128])
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Initial slots shape: torch.Size([1, 2, 128])
Updates shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Updated slots shape: torch.Size([1, 2, 128])

Iteratio

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Attention shape: torch.Size([1, 2, 3072])
Normalized feature shape: torch.Size([1, 3072, 128])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Initial slots shape: torch.Size([1, 2, 128])
Query shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape:

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Normalized feature shape: torch.Size([1, 3072, 128])
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Initial slots shape: torch.Size([1, 2, 128])
Upd

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Initial slots shape: torch.Size([1, 2, 128])
Query shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Attention shape: torch.Size([1, 2, 3072])
Query shape: torch.Size([1, 2, 128])
Updates shape: torch.Size([1, 2, 128])
Attention shape: 

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Initial slots shape: torch.Size([1, 2, 128])
Normalized feature shape: torch.Size([1, 3072, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Initial slots shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Updates shape: torch.Size([1, 2, 128])
Query shape: torch.Size([1, 2, 128])
Updated slots sha

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 3
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Flow updates shape: to

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 3
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Flow updates shape: to

shape of feature1: torch.Size([1, 128, 48, 64])shape of feature1: torch.Size([1, 128, 48, 64])


Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Attention shape: torch.Size([1, 2, 3072])
Normalized feature shape: torch.Size([1, 3072, 128])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Query shape: torch.Size([1, 2, 128])
Initial slots shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Upd

shape of feature1: torch.Size([1, 128, 48, 64])shape of feature1: torch.Size([1, 128, 48, 64])


Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Updates shape: torch.Size([1, 2, 128])
Normalized feature shape: torch.Size([1, 3072, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Initial slots shape: torch.Size([1, 2, 128])
Upd

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Query shape: torch.Size([1, 2, 128])
Normalized feature shape: torch.Size([1, 3072, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Initial slots shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteratio

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Updates shape: torch.Size([1, 2, 128])
Normalized feature shape: torch.Size([1, 3072, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Query shape: torch.Size([1, 2, 128])
Initial slots shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Upd

shape of feature1: torch.Size([1, 128, 48, 64])
shape of feature1: torch.Size([1, 128, 48, 64])

Forward input shapes:
feature0: torch.Size([1, 128, 48, 64])
flow: torch.Size([1, 2, 48, 64])
local_window_attn: False

Processing global attention with dims b=1, c=128, h=48, w=64
Normalized feature shape: torch.Size([1, 3072, 128])
Initial slots shape: torch.Size([1, 2, 128])
K shape: torch.Size([1, 3072, 128]), V shape: torch.Size([1, 3072, 128])

Iteration 1
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 2
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Iteration 3
Query shape: torch.Size([1, 2, 128])
Attention shape: torch.Size([1, 2, 3072])
Updates shape: torch.Size([1, 2, 128])
Updated slots shape: torch.Size([1, 2, 128])

Flow updates shape: to