In [2]:
from natten import (
  NeighborhoodAttention1D,
  NeighborhoodAttention2D,
  NeighborhoodAttention3D,
)
# All three modules are built with the same hyper-parameters; only the
# dimensionality of the neighborhood (1D / 2D / 3D) differs.
_na_config = dict(dim=128, kernel_size=7, dilation=3, num_heads=4)

na1d = NeighborhoodAttention1D(**_na_config)
na2d = NeighborhoodAttention2D(**_na_config)
na3d = NeighborhoodAttention3D(**_na_config)

In [3]:
# Display the 1D module's configuration summary string
# (head_dim, num_heads, kernel_size, dilation, has_bias).
na1d.extra_repr()

'head_dim=32, num_heads=4, kernel_size=7, dilation=3, has_bias=True'

In [10]:
from pathlib import Path
import pandas as pd

datapath = '/mnt/datadrive'
# Walk the whole tree below `datapath` and keep only the entries whose
# suffix is exactly '.png'.
x = [entry for entry in Path(datapath).rglob('*') if entry.suffix == '.png']
# One row per image; the single column holds the pathlib.Path objects.
df = pd.DataFrame(x, columns=['fullpath'])
df.head(5)

Unnamed: 0,fullpath
0,/mnt/datadrive/asos_dataset/tshirts_orig_biges...
1,/mnt/datadrive/asos_dataset/tshirts_orig_biges...
2,/mnt/datadrive/asos_dataset/tshirts_orig_biges...
3,/mnt/datadrive/asos_dataset/tshirts_orig_biges...
4,/mnt/datadrive/asos_dataset/tshirts_orig_biges...


In [15]:
# The stem of each image's parent directory is parsed as an integer item id.
df['item_idx'] = df['fullpath'].map(lambda img_path: int(img_path.parent.stem))
# Number of images found per item id.
df['item_idx'].value_counts()

item_idx
0      4
471    4
465    4
466    4
467    4
      ..
276    4
277    4
278    4
279    4
99     4
Name: count, Length: 609, dtype: int64

In [24]:
# Write the fullpath/item_idx table to 'all_imgs.csv' (relative to the current
# working directory), leaving out the DataFrame index column.
# NOTE(review): SyntheticTryonDataset reads an 'all_imgs.csv' from an absolute
# path — presumably this same file; confirm the notebook's working directory.
df.to_csv('all_imgs.csv', index=False)

In [23]:
from PIL import Image
from matplotlib import pyplot as plt

# for j in [0, 1, 2, 3, 44, 55, 99]:
#     for i in df[df['item_idx']==j]['fullpath'].values:
#         print(i)
#         img = Image.open(i)
#         plt.imshow(img)
#         plt.show()    

In [26]:
# Peek at the first (item id, rows) pair produced by the grouping, then stop.
for item_id, item_rows in df.groupby(by='item_idx'):
    print(item_id, item_rows, sep='\n')
    break

0
                                            fullpath  item_idx
0  /mnt/datadrive/asos_dataset/tshirts_orig_biges...         0
1  /mnt/datadrive/asos_dataset/tshirts_orig_biges...         0
2  /mnt/datadrive/asos_dataset/tshirts_orig_biges...         0
3  /mnt/datadrive/asos_dataset/tshirts_orig_biges...         0


In [7]:
import torch
from torch.utils.data import DataLoader, Dataset

from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from collections import defaultdict

class SyntheticTryonDataset(Dataset):
    """Map-style dataset that yields random tensors shaped like try-on samples.

    The image index CSV is loaded and grouped by item id at construction time,
    but the samples themselves are still synthetic (``torch.randn``).
    TODO: replace the random tensors in ``__getitem__`` with real images looked
    up through ``self.df`` / ``self.items_reverse_index``.
    """

    # Index CSV that is read when no ``csv_path`` is passed to ``__init__``.
    DEFAULT_INDEX_CSV = '/home/roman/tryondiffusion_implementation/tryondiffusion_danny/all_imgs.csv'

    def __init__(self, num_samples, image_size=(64,64), pose_size=(18, 2), csv_path=None):
        """
        Args:
            num_samples (int): Number of samples in the dataset.
            image_size (tuple): The height and width of the images (height, width).
            pose_size (tuple): The size of the pose tensors (default: (18, 2)).
            csv_path (str | os.PathLike | None): CSV with ``fullpath`` and
                ``item_idx`` columns. ``None`` (default) reads
                ``DEFAULT_INDEX_CSV``.
        """
        index_csv = self.DEFAULT_INDEX_CSV if csv_path is None else csv_path
        self.df = pd.read_csv(index_csv)

        # item_idx -> array of the image paths that belong to that item.
        # Built from self.df (the CSV just loaded) so the dataset does not
        # depend on a notebook-global DataFrame being alive in the kernel.
        self.items_reverse_index = {
            item_idx: rows['fullpath'].values
            for item_idx, rows in self.df.groupby(by='item_idx')
        }

        self.num_samples = num_samples
        self.image_size = image_size
        self.pose_size = pose_size

    def __len__(self):
        """Return the number of samples requested at construction time."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return one synthetic sample as a dict of tensors.

        Args:
            idx (int): Sample position, ``0 <= idx < len(self)``.

        Returns:
            dict: ``person_images``, ``ca_images`` and ``garment_images`` of
            shape ``(3, *image_size)``; ``person_poses`` and ``garment_poses``
            of shape ``pose_size``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        # Without this check, plain iteration (`for sample in ds`) would never
        # stop: it relies on IndexError to detect the end of the dataset.
        if not 0 <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for a dataset with {self.num_samples} samples"
            )

        person_image = torch.randn(3, *self.image_size)
        ca_image = torch.randn(3, *self.image_size)
        garment_image = torch.randn(3, *self.image_size)
        person_pose = torch.randn(*self.pose_size)
        garment_pose = torch.randn(*self.pose_size)

        sample = {
            "person_images": person_image,
            "ca_images": ca_image,
            "garment_images": garment_image,
            "person_poses": person_pose,
            "garment_poses": garment_pose,
        }

        return sample


def tryondiffusion_collate_fn(batch):
    """Merge a list of per-sample dicts into one dict of batched tensors.

    For each of the five sample keys, the tensors of all items in ``batch``
    are stacked along a new leading dimension with ``torch.stack``.
    """
    field_names = (
        "person_images",
        "ca_images",
        "garment_images",
        "person_poses",
        "garment_poses",
    )
    return {
        field: torch.stack([sample[field] for sample in batch])
        for field in field_names
    }

ds = SyntheticTryonDataset(num_samples=100)
# Smoke test: print every field name and tensor shape of the first sample only.
for first_sample in ds:
    for field, tensor in first_sample.items():
        print(f"{field}\n{tensor.shape}\n")
    break

person_images
torch.Size([3, 64, 64])

ca_images
torch.Size([3, 64, 64])

garment_images
torch.Size([3, 64, 64])

person_poses
torch.Size([18, 2])

garment_poses
torch.Size([18, 2])

