In [1]:
import os

In [2]:
os.environ['WANDB_API_KEY'] = '5ea9498c2870cf3e2ad58c7352ac3e8949bd2a82'

In [3]:
import sys

In [4]:
!{sys.executable} -m pip install hydra-core wandb transformers datasets opencv-python --force-reinstall

Collecting hydra-core
  Using cached hydra_core-1.2.0-py3-none-any.whl (151 kB)
Collecting wandb
  Using cached wandb-0.12.21-py2.py3-none-any.whl (1.8 MB)
Collecting transformers
  Using cached transformers-4.20.1-py3-none-any.whl (4.4 MB)
Collecting datasets
  Using cached datasets-2.4.0-py3-none-any.whl (365 kB)
Collecting opencv-python
  Using cached opencv_python-4.6.0.66-cp37-abi3-macosx_11_0_arm64.whl (30.0 MB)
Collecting packaging
  Using cached packaging-21.3-py3-none-any.whl (40 kB)
Collecting antlr4-python3-runtime==4.9.*
  Using cached antlr4_python3_runtime-4.9.3-py3-none-any.whl
Collecting omegaconf~=2.2
  Using cached omegaconf-2.2.2-py3-none-any.whl (79 kB)
Collecting docker-pycreds>=0.4.0
  Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting six>=1.13.0
  Using cached six-1.16.0-py2.py3-none-any.whl (11 kB)
Collecting protobuf<4.0dev,>=3.12.0
  Using cached protobuf-3.20.1-py2.py3-none-any.whl (162 kB)
Collecting setuptools

In [6]:
import os
from hydra import compose, initialize
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate
from torch.utils.data import Dataset
from transformers import ResNetForImageClassification

import src.utils as utils
from src.conf import Config, DataConfig
from src.data import DataHandler
from src.data.dataset import CellPaintingDatasetCached
from src.model import ConvNext, Dummy, ModelTrainer, ResNet, ViT, DeiT

os.chdir("/Users/maciej.filanowicz/CellPainting/src")
initialize(version_base="1.2", config_path="conf", job_name="test_app")
cfg = compose(config_name="config", return_hydra_config=True)
dataset_config = instantiate(cfg.dataset)
train_config = instantiate(cfg.train)

In [7]:
from  transformers import SchedulerType,ResNetForImageClassification,ResNetConfig
import torch

In [8]:
from torchvision import transforms

In [9]:
# transforms = transforms.Compose([transforms.ToTensor(),
#                     transforms.Normalize(mean=[163.24,  536.39, 425.26, 581.64],std=[ 204.27, 1386.95, 917.2 , 519.7 ]),
#                     transforms.Resize(224)])

In [10]:
from typing import Any, Dict, List, Tuple, Type, Union

import cv2
import numpy as np
import numpy.typing as npt
from numpy.lib.stride_tricks import as_strided
from tifffile import imread

In [11]:
def view_as_blocks(arr_in: npt.NDArray[Any], block_shape: Tuple[int, int, int] = (540,540,4)) -> Any:
    if arr_in.shape[0] % block_shape[0] or arr_in.shape[1] % block_shape[1] or arr_in.shape[2] != block_shape[2]:
        raise ValueError("Incompatible block shape!")
    bs = np.array(block_shape)
    arr_shape = np.array(arr_in.shape)
    new_shape = tuple(arr_shape // bs) + tuple(bs)
    new_strides = tuple(arr_in.strides * bs) + arr_in.strides
    arr_out = as_strided(arr_in, shape=new_shape, strides=new_strides)
    arr_out = np.squeeze(arr_out)
    return arr_out

In [12]:

from torch.utils.data import Dataset

import pandas as pd

metadata = pd.read_csv("../data/processed/meta_data.csv")
# train_dataset = CellPaintingDatasetCached(
#     metadata[metadata.folder_name == "train"], dataset_config, dataset_config.train_transforms
# )
test_dataset = CellPaintingDatasetCached(
    metadata[metadata.folder_name == "test"], dataset_config,None
)
# val_dataset = CellPaintingDatasetCached(
#     metadata[metadata.folder_name == "val"], dataset_config, dataset_config.test_transforms
# )

In [13]:
test_images = [test_dataset.__getitem__(0)['pixel_values'],test_dataset.__getitem__(1)['pixel_values']]
test_images = np.stack(test_images)

In [22]:
from transformers.modeling_outputs import ImageClassifierOutput

In [14]:
from torch.nn.functional import softmax

In [15]:
config = ResNetConfig(num_labels=9,
                      pretrained_model_name = '../models/resnet18/pytorch_model.bin',
                      label2id=dict(dataset_config.label2id),
                      id2label=dict(dataset_config.id2label),
                      num_channels=4,
                      depths=[2,2,2,2],
                      downsample_in_first_stage= False,
                      embedding_size=64,
                      hidden_sizes=[64,128,256,512],
                      layer_type='basic',
                      device='cuda' if torch.cuda.is_available() else 'cpu')

class CellPainingModel:
    def __init__(self, config: ResNetConfig, transforms: transforms = None) -> None:
        self.model = ResNetForImageClassification(config)
        self.model.load_state_dict(torch.load(config.pretrained_model_name,map_location=torch.device(config.device)))
        self._transforms = transforms

    def _patch_images(self,images):
        patches = np.array([view_as_blocks(i) for i in images])
        return patches

    def transform(self,images):
        images = torch.stack([self._transforms(i) for i in images])
        return images

    def __call__(self, pixel_values):
        images = self._patch_images(pixel_values)
        input_shape = images.shape
        images = images.reshape(-1,540,540,4)
        if self._transforms:
            images = self.transform(images)
        output = self.model(pixel_values=images,labels=None)
        output_shape = (input_shape[0],input_shape[1]*input_shape[2])
        logits = output.logits.reshape((*output_shape,-1))
        mean_logits = softmax(torch.mean(logits,axis=1),-1).detach().cpu().numpy()
        return mean_logits

TypeError: Object of type Compose is not JSON serializable

In [16]:
model = CellPainingModel(config,transforms)

In [17]:
model(test_images)

array([[4.5147352e-02, 3.6047108e-02, 6.6713601e-01, 5.8229011e-06,
        2.1507222e-02, 6.2058074e-03, 1.6777913e-01, 5.1966440e-02,
        4.2051310e-03],
       [2.7855344e-03, 1.1297154e-05, 9.9607098e-01, 5.3346830e-06,
        3.5054658e-05, 1.0283979e-03, 2.2110253e-05, 3.3837998e-05,
        7.4594254e-06]], dtype=float32)