In [None]:
from datasets import load_dataset

# Load the "train" and "test" splits of the matthieulel/galaxy10_decals dataset,
# in that order, and bind them to `train` and `test`.
train, test = (
    load_dataset("matthieulel/galaxy10_decals", split=split_name)
    for split_name in ("train", "test")
)

In [7]:
# Display the training split's summary together with its first record
# (a dict holding a PIL image and an integer label, per the output below).
train, train[0]

(Dataset({
     features: ['image', 'label'],
     num_rows: 15962
 }),
 {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=256x256>,
  'label': 7})

In [18]:
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import timm

# assert timm.__version__ == "0.3.2"  # version check
import timm.optim.optim_factory as optim_factory

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler

import models_mae

from engine_pretrain import train_one_epoch

In [5]:
# Look up the "mae_vit_base_patch16" factory by name in the models_mae module
# namespace and call it with img_size=256; the trailing bare expression makes
# the notebook print the resulting module tree.
mae_model = vars(models_mae)["mae_vit_base_patch16"](img_size=256)
mae_model

MaskedAutoencoderViT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=3072, out_features=768

In [24]:
from torch.utils.data import DataLoader

# Training-time augmentation pipeline applied to each PIL image.
transform_train = transforms.Compose([
    # Random crop covering 20%-100% of the source area, resized to 256x256.
    # InterpolationMode.BICUBIC is the enum equivalent of the bare integer 3
    # (PIL's bicubic code) and is the form torchvision documents for this
    # argument, so the intent no longer needs a side comment to decode.
    transforms.RandomResizedCrop(
        256,
        scale=(0.2, 1.0),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.RandomHorizontalFlip(),
    # Convert the PIL image to a tensor before normalisation.
    transforms.ToTensor(),
    # Per-channel standardisation with fixed mean/std constants.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def t_func(data):
    """Apply ``transform_train`` to every entry of ``data["image"]``.

    The mapping is updated in place: its ``"image"`` entry is replaced by a
    new list of transformed samples, and the same mapping object is returned.
    All other keys are left untouched.
    """
    transformed = []
    for sample in data["image"]:
        transformed.append(transform_train(sample))
    data["image"] = transformed
    return data

# Attach t_func as the dataset's transform and rebind `train` to the result.
# NOTE(review): per the Hugging Face `datasets` docs, with_transform applies
# the function lazily when rows are accessed rather than eagerly — confirm.
train = train.with_transform(t_func)

def collate(data):
    """Merge a list of per-example dicts into a single batched dict.

    Intended for use as ``DataLoader(..., collate_fn=collate)``.

    Args:
        data: Sequence of examples. Each example is a mapping with an
            ``"image"`` tensor (all of identical shape) and an integer
            ``"label"`` — the column names of the dataset loaded in this
            notebook.

    Returns:
        dict with two entries:
            ``"image"``: the example tensors stacked along a new leading axis.
            ``"label"``: a 1-D tensor of the labels, in input order.

    Raises:
        RuntimeError: if ``data`` is empty (``torch.stack`` needs at least
            one tensor) or the image tensors differ in shape.
    """
    # Iterate over the actual parameter; the previous loop referenced an
    # undefined name and used plural keys that the dataset does not have.
    images = torch.stack([example["image"] for example in data])
    labels = torch.tensor([example["label"] for example in data])
    # Keys mirror the dataset columns so code indexing batch["image"] keeps working.
    return {"image": images, "label": labels}

# Batch the transformed training set: 8 examples per batch, with shuffling on.
# No collate_fn is passed, so DataLoader's default collation builds the batch
# (the output below shows it keyed by 'image').
train_loader = DataLoader(train, batch_size=8, shuffle=True)
# Pull a single batch to inspect. No seed is set in the visible cells, so the
# batch contents can differ between runs.
sample_batch = next(iter(train_loader))
sample_batch

{'image': tensor([[[[-2.0665, -1.9809, -1.9295,  ..., -1.8439, -1.5699, -1.4500],
           [-1.1932, -1.1932, -1.4672,  ..., -1.8268, -1.9467, -1.8439],
           [-0.4054, -0.4397, -1.0390,  ..., -1.4843, -1.7412, -1.9295],
           ...,
           [-1.8953, -1.2103, -0.9363,  ..., -1.6042, -1.3987, -1.3130],
           [-1.6555, -1.7240, -1.5528,  ..., -1.1418, -1.1760, -1.2788],
           [-1.3130, -1.6384, -1.8097,  ..., -1.3302, -1.4329, -1.4500]],
 
          [[-1.7731, -1.7906, -2.0182,  ..., -1.3529, -1.1604, -1.5105],
           [-1.5455, -1.4055, -1.5105,  ..., -1.5805, -1.5105, -1.3880],
           [-1.6331, -1.4405, -1.2129,  ..., -1.6155, -1.1779, -0.6352],
           ...,
           [-1.9832, -1.6155, -1.1954,  ..., -1.7906, -1.6155, -1.3179],
           [-1.6506, -1.7031, -1.3004,  ..., -1.0203, -1.1954, -1.3354],
           [-1.2304, -1.3529, -1.3704,  ..., -1.0203, -1.3004, -1.5630]],
 
          [[-1.3339, -1.3687, -1.6476,  ..., -1.5604, -1.0376, -0.9504],
    

In [25]:
# Forward pass of the MAE model on the sampled image batch. The displayed
# result is a tuple whose first element is a scalar loss tensor with a grad_fn.
mae_model(sample_batch["image"])

(tensor(2.5245, grad_fn=<DivBackward0>),
 tensor([[[-1.8495,  1.0184,  0.2936,  ..., -0.8154,  0.7890,  1.4831],
          [-1.8695,  0.1210,  0.3346,  ..., -0.7111,  0.1141,  0.6251],
          [-1.8758,  1.0836,  0.4703,  ..., -0.7200,  0.9403,  1.5909],
          ...,
          [-1.7749,  0.6751,  0.0179,  ..., -0.6966,  0.2782,  1.1669],
          [-1.8075,  0.7336,  0.2283,  ..., -0.6475,  0.3509,  1.2289],
          [-1.7645,  0.7319,  0.4257,  ..., -0.5061,  0.3681,  1.2290]],
 
         [[-1.7087,  0.9702,  0.3124,  ..., -0.7133,  0.7553,  1.5740],
          [-1.7836,  0.1173,  0.3013,  ..., -0.6925,  0.1195,  0.6781],
          [-1.9063,  0.1161,  0.3951,  ..., -0.6792,  0.2401,  0.6728],
          ...,
          [-1.6721,  0.6708, -0.0133,  ..., -0.6744,  0.2769,  1.2169],
          [-1.7025,  0.7303,  0.2006,  ..., -0.6250,  0.3524,  1.2835],
          [-1.9035,  1.2943,  0.6327,  ..., -0.5391,  1.3039,  1.6136]],
 
         [[-1.6617,  0.2217,  0.2063,  ..., -0.6532,  0.059