In [None]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import numpy as np
import os
from pathlib import Path
import sys


# Optionally extend the import search path: when the TRIPATH_DIR environment
# variable is set to a non-empty value that is not already on sys.path, append
# it (presumably the checkout providing the `models` package imported below).
tri_path = os.environ.get('TRIPATH_DIR')
if tri_path:
    if tri_path not in sys.path:
        sys.path.append(tri_path)
from models.feature_extractor import swin3d_b
from tqdm import tqdm


In [10]:
class PatchDataset(Dataset):
    """Dataset that serves patches stored as individual ``.npy`` files.

    Each item is a ``(tensor, path_string)`` pair: the array loaded from disk,
    converted to float32, given a new leading axis and tiled three times along
    it, together with the originating path rendered as a string.
    """

    def __init__(self, paths):
        """Keep the sequence of ``.npy`` file paths to serve."""
        self.paths = paths

    def __len__(self):
        """Number of patch files in the dataset."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load the ``idx``-th patch and return ``(tensor, str(path))``."""
        source = self.paths[idx]
        volume = torch.from_numpy(np.load(source)).float()
        # Add a leading axis and tile it 3x; assumes the stored array is 3-D so
        # the result is (3, *array.shape) — TODO confirm against the patch files.
        tiled = volume.unsqueeze(0).repeat(3, 1, 1, 1)
        return tiled, str(source)

In [11]:
class PatchEncoder(nn.Module):
    """Thin ``nn.Module`` wrapper around the ``swin3d_b`` feature extractor."""

    def __init__(self):
        """Build the backbone and call its ``load_weights()`` method."""
        super().__init__()
        backbone = swin3d_b()
        self.encoder = backbone
        # Presumably restores pretrained weights; behaviour is defined by the
        # project's swin3d_b implementation — confirm there.
        self.encoder.load_weights()

    def forward(self, x):
        """Run ``x`` through the wrapped encoder and return its output unchanged."""
        return self.encoder(x)

In [12]:
# Dataset root comes from the DATA_DIR environment variable (KeyError if unset).
data_dir = Path(os.environ['DATA_DIR'])
patch_dir = data_dir.joinpath("patches")
# Every .npy file anywhere below patch_dir.
paths = [npy_file for npy_file in patch_dir.rglob("*.npy")]
dataset = PatchDataset(paths)
# Unshuffled batches of 32 patches.
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=False)
encoder = PatchEncoder()

Loading pretrained video weights


In [17]:
# Encode every patch batch and write one embedding file per patch to
# <patch's grandparent dir>/embeddings/<parent dir name>_<patch stem>.pt

# Inference only: put the encoder in eval mode so any training-time layers it
# may contain (e.g. dropout) do not perturb the embeddings.
encoder.eval()
with torch.no_grad():
    for patches, patch_paths in tqdm(dataloader):
        embeddings = encoder(patches)
        for embedding, patch_path in zip(embeddings, patch_paths):
            patch_path = Path(patch_path)
            save_path = patch_path.parent.parent / "embeddings" / (patch_path.parent.name + "_" + patch_path.stem + ".pt")
            # Create the output directory (and any missing ancestors) without
            # failing if it already exists.
            save_path.parent.mkdir(parents=True, exist_ok=True)
            # `embedding` is a slice of the batch output; torch.save serializes
            # a view together with its whole underlying storage, so clone first
            # to store only this patch's values.
            torch.save(embedding.clone(), save_path)


100%|██████████| 39/39 [01:37<00:00,  2.50s/it]
