In [None]:
import os
from importlib import reload

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage
import torch
from skimage import io

import folded_dataset
# Re-execute the already-imported local module so that edits made to
# folded_dataset.py after the kernel first imported it take effect here.
reload(folded_dataset)

In [None]:
# Dataset to process and the root directory that holds every dataset.
# root_dir ends with a slash because paths below are built by plain string
# concatenation.
dataset_name = 'HPA'
root_dir = '/scr/zchen/datasets/morphem_70k_2.0/'

# Metadata table for the chosen dataset; its `file_path` column is used
# further down to locate the image files relative to root_dir.
df_path = f'{root_dir}{dataset_name}/enriched_meta.csv'
df = pd.read_csv(filepath_or_buffer=df_path)

In [None]:
# Build the dataset from the same metadata CSV that was loaded into `df`.
# NOTE(review): target_labels presumably names the metadata column returned
# as the label for each sample - confirm in folded_dataset.SingleCellDataset.
dataset = folded_dataset.SingleCellDataset(
    csv_file=df_path,
    root_dir=root_dir,
    target_labels='train_test_split',
)


In [None]:
# Visual sanity check for a few rows: the image file read straight from disk
# (left) next to the tensor the dataset returns for the same index (right).
sample_inds = [0, 100, 200]
for ind in sample_inds:
    fig, (ax_file, ax_dataset) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

    image_from_disk = io.imread(root_dir + df.iloc[ind].file_path)
    ax_file.imshow(image_from_disk)

    # Move the first axis to the end and keep at most three channels so that
    # imshow can render the array (assumes a channels-first tensor - TODO confirm).
    image_from_dataset = np.transpose(dataset[ind][0].numpy(), (1, 2, 0))[..., :3]
    ax_dataset.imshow(image_from_dataset)


## ResNet Feature Extraction

In [None]:
from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
# Pick the compute device. GPU 5 is preferred, but that index only exists on
# hosts with at least six visible GPUs; on smaller CUDA hosts fall back to the
# first GPU instead of producing an invalid device, and use the CPU when CUDA
# is not available at all.
preferred_gpu = 5
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > preferred_gpu:
    device = torch.device(f"cuda:{preferred_gpu}")
else:
    device = torch.device("cuda:0")

In [None]:
# Directory (under the data root) and file name used for the extracted
# features. The closing quote on feature_dir was missing, which made this
# cell a SyntaxError.
feature_dir = f'{root_dir}features'
feature_file = 'pretrained_resnet18_features.npy'

In [None]:
# ResNet-18 with ImageNet-pretrained weights, used purely as a frozen
# feature extractor.
weights = ResNet18_Weights.IMAGENET1K_V1

# Freshly constructed modules start in training mode; switch to eval() so
# BatchNorm layers normalise with their stored running statistics rather
# than per-batch statistics (and do not update those statistics while we
# extract features).
m = resnet18(weights=weights).to(device).eval()

# Keep every child module except the last one (the final classification
# layer), so the network outputs features instead of class logits.
feature_extractor = torch.nn.Sequential(*list(m.children())[:-1]).to(device).eval()

In [None]:
# Iterate over the dataset in fixed-size batches. shuffle=False keeps the
# batches in dataset index order, so extracted features stay in that order.
train_dataloader = DataLoader(
    dataset=dataset,
    batch_size=256,
    shuffle=False,
)

In [None]:
# Preprocessing transform bundled with the pretrained weights.
preprocess = weights.transforms()

# Make sure the extractor is in inference mode even if the cell that built it
# did not do so: BatchNorm must use its stored running statistics, otherwise
# the features depend on the composition of each batch.
feature_extractor.eval()

# One array per batch; concatenated into a single matrix in a later cell.
all_feat = []

# no_grad(): we only need forward activations, so skip building autograd
# graphs (saves memory and time on every batch).
with torch.no_grad():
    for images, _labels in tqdm(train_dataloader, total=len(train_dataloader)):
        batch_feat = []
        for i in range(images.shape[1]):
            # The network expects 3-channel input, so each image channel is
            # presented on its own, repeated three times. expand() returns a
            # view, so no copy of the batch is needed.
            channel = images[:, i, :, :].unsqueeze(1)
            expanded = channel.expand(-1, 3, -1, -1)

            expanded = preprocess(expanded).to(device)
            feat_temp = feature_extractor(expanded).cpu().numpy()
            batch_feat.append(feat_temp)

        # Concatenate the per-channel features along the feature axis so each
        # sample gets one vector covering all of its channels.
        all_feat.append(np.concatenate(batch_feat, axis=1))

In [None]:
# Stack the per-batch arrays along the sample axis and drop the two trailing
# singleton axes (axes 2 and 3) left over from the extractor output.
# Note: this rebinds all_feat from a list of arrays to a single array, so the
# cell is meant to run once per extraction pass.
all_feat = np.squeeze(np.concatenate(all_feat, axis=0), axis=(2, 3))
all_feat.shape

In [None]:
# Persist the extracted features for this dataset as a .npy file.
feature_path = f'{root_dir}features/{dataset_name}/pretrained_resnet18_features.npy'

# np.save does not create missing parent directories; create them first so a
# long extraction run does not end in an error on a fresh data root.
os.makedirs(os.path.dirname(feature_path), exist_ok=True)
np.save(feature_path, all_feat)