In [1]:
!pip install einops
!pip install --upgrade torch torchvision torchaudio
!pip install wandb

Defaulting to user installation because normal site-packages is not writeable
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
Installing collected packages: einops
Successfully installed einops-0.6.0
Defaulting to user installation because normal site-packages is not writeable
Collecting torch
  Downloading torch-2.0.0-cp39-cp39-win_amd64.whl (172.3 MB)
Collecting torchvision
  Downloading torchvision-0.15.1-cp39-cp39-win_amd64.whl (1.2 MB)
Collecting torchaudio
  Downloading torchaudio-2.0.1-cp39-cp39-win_amd64.whl (2.1 MB)
Installing collected packages: torch, torchvision, torchaudio
Successfully installed torch-2.0.0 torchaudio-2.0.1 torchvision-0.15.1




Defaulting to user installation because normal site-packages is not writeable
Collecting wandb
  Downloading wandb-0.14.0-py3-none-any.whl (2.0 MB)
Collecting setproctitle
  Downloading setproctitle-1.3.2-cp39-cp39-win_amd64.whl (11 kB)
Collecting GitPython!=3.1.29,>=1.0.0
  Downloading GitPython-3.1.31-py3-none-any.whl (184 kB)
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.18.0-py2.py3-none-any.whl (194 kB)
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-4.0.10-py3-none-any.whl (62 kB)
Collecting smmap<6,>=3.0.1
  Downloading smmap-5.0.0-py3-none-any.whl (24 kB)
Collecting urllib3<1.27,>=1.21.1
  Downloading urllib3-1.26.15-py2.py3-none-any.whl (140 kB)
Building wheels for collected packages: pathtools
  Building wheel for pathtools (setup.py): started
  Building wheel for pathtools (setup.py): finished with status

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
conda-repo-cli 1.0.4 requires pathlib, which is not installed.


In [2]:
import os
os.environ['KAGGLE_USERNAME'] = ##
os.environ['KAGGLE_KEY'] = ##

SyntaxError: invalid syntax (1668384032.py, line 2)

In [None]:
!kaggle competitions download -c vesuvius-challenge-ink-detection
!unzip /content/vesuvius-challenge-ink-detection.zip

In [None]:
import wandb
!wandb login

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

## Pretraining a Masked Auto-Encoder

Slices of tomography have some kind of structure that isn't obvious to us. By pre-training an autoencoder to reconstruct these slices, we obtain a feature representation that captures some of this internal structure, and can be leveraged for downstream tasks. I theorize that even the outer layers, not useful for ink detection, may still be useful enough for unsupervised pretraining. For downstream supervised learning, we use this pretrained MAE to extract the "features" for each slice, and can then train another model on top of these representations.

In [None]:
import PIL.Image as Image
import torch.utils.data as data
from typing import List, Tuple
from pathlib import Path
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from einops import rearrange, repeat
from tqdm.auto import tqdm

In [None]:
class Patch3DDataset(torch.utils.data.IterableDataset):
    def __init__(
            self,
            fragments: List[Path],
            patch_shape: Tuple[int, int, int],
            buffer_size: int = 50000
    ):
        self.fragments = sorted(map(lambda path: path.resolve(), fragments))
        self.z_dim, self.y_dim, self.x_dim = patch_shape
        # self.load_inklabels = load_inklabels
        # self.filter_edge_pixels = filter_edge_pixels

        self.transform = transforms.Compose([
            transforms.Lambda(lambda patch: patch / 65535.0),
            transforms.Lambda(lambda patch: torch.tensor(patch, dtype=torch.float32)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip()
        ])

        self.buffer_size = buffer_size

    def __iter__(self):
        # buffer used for pseudo-shuffing so the patches aren't returned completely sequentially
        buffer = []

        def add_to_buffer(item):
            if len(buffer) < self.buffer_size:
                buffer.append(item)
            else:
                idx = random.randint(0, self.buffer_size - 1)
                buffer[idx] = item

        def yield_random_from_buffer():
            idx = random.randint(0, len(buffer) - 1)
            return buffer.pop(idx)

        # Load sequentially
        while True: # infinite "streaming" dataset, use the number of steps to trigger stopping
            for fragment_id, fragment_path in enumerate(self.fragments):
                fragment_path = fragment_path.resolve()  # absolute path
                mask = np.array(Image.open(str(fragment_path / "mask.png")).convert("1"))
                fragment_y, fragment_x = mask.shape
                y_pad = self.y_dim - (fragment_y % self.y_dim)
                x_pad = self.x_dim - (fragment_x % self.x_dim)
                mask = np.pad(mask, pad_width = ((0, y_pad), (0, x_pad)), mode='constant', constant_values=0)

                surface_volume_paths = sorted(
                    (fragment_path / "surface_volume").rglob("*.tif")
                )

                for z_idx in range(0, len(surface_volume_paths), self.z_dim):
                    images = [
                        np.array(Image.open(fn)) for fn in surface_volume_paths[z_idx:z_idx + self.z_dim]
                    ]
                    image_stack = np.stack(images, axis=0)

                    image_stack = np.pad(image_stack, pad_width = ((0, 0), (0, y_pad), (0, x_pad)), mode='constant', constant_values=0)
                    for y in range(0, mask.shape[0], self.y_dim):
                        for x in range(0, mask.shape[1], self.x_dim):
                            mask_chunk = mask[y:y+self.y_dim, x:x + self.x_dim]
                            patch = image_stack[:, y:y+self.y_dim, x:x+self.x_dim]
                            mask_mean = np.mean(mask_chunk)
                            # only train on patches that are >= 30% data and a full patch
                            if mask_mean < 0.3 or patch.shape[0] != self.z_dim:
                                continue
                            else:
                                add_to_buffer((patch, mask_chunk))
                                # if buffer full, then we are ready to randomly return a patch
                                if len(buffer) == self.buffer_size:
                                    random_patch, random_mask_chunk = yield_random_from_buffer()
                                    yield self.transform(random_patch), random_mask_chunk

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, n_heads, embed_dim, d_qkv, dropout):
        super().__init__()
        self.d_qkv = d_qkv
        self.norm1 = nn.LayerNorm(embed_dim)
        self.to_qkv = nn.Linear(embed_dim, d_qkv * n_heads * 3, bias=False)
        self.attn_dropout_p = dropout
        self.out_proj = nn.Linear(d_qkv * n_heads, embed_dim, bias=False)
        self.resid_dropout1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim)
        )
        self.resid_dropout2 = nn.Dropout(dropout)

    def forward(self, X):
        normed1 = self.norm1(X)
        Q, K, V = rearrange(self.to_qkv(normed1), "b l (h ddd) -> b h l ddd", ddd=(3 * self.d_qkv)).chunk(3, dim=-1) # b, h, l, d_attn
        attn = F.scaled_dot_product_attention(Q, K, V, dropout_p=self.attn_dropout_p, is_causal=False).transpose(1, 2) # b, l, h, d_attn
        attn_out = X + self.resid_dropout1(self.out_proj(attn.flatten(2, 3)))
        normed2 = self.norm2(attn_out)
        return attn_out + self.resid_dropout2(self.ffn(normed2))

In [None]:
class MAEVisionTransformer(nn.Module):
    def __init__(self, img_size, patch_size, n_channels, mask_prob,
                 encoder_depth, decoder_depth, n_heads, embed_dim, d_qkv, dropout=0.1):
        super().__init__()
        assert img_size % patch_size == 0, "Image must divide evenly into mini-patches."
        self.mask_prob = mask_prob
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_channels = n_channels
        self.patch_emb = nn.Linear(patch_size**2 * n_channels, embed_dim)
        self.pos_embs = nn.Parameter(torch.zeros(((img_size // patch_size)**2, embed_dim)) * 0.1)
        self.mask_token = nn.Parameter(torch.randn((embed_dim,)) * 0.1)
        self.encoder = nn.Sequential(*[
            TransformerBlock(n_heads, embed_dim, d_qkv, dropout) for _ in range(encoder_depth)
        ])
        self.decoder = nn.Sequential(*[
            TransformerBlock(n_heads, embed_dim, d_qkv, dropout) for _ in range(decoder_depth)
        ])
        self.output_head = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * embed_dim, patch_size**2 * n_channels)
        )

    def forward(self, X, training=True):
        # Embed patches, apply positional embeddings
        patches = rearrange(X, "b c (h1 h2) (w1 w2) -> b (h1 w1) (c h2 w2)", h2=self.patch_size, w2=self.patch_size)
        seq = self.patch_emb(patches) + self.pos_embs.unsqueeze(0)
        b, l, d = seq.shape

        # If training, keep subsample of patches
        if training:
            with torch.no_grad():
                n_to_keep = int(np.floor(l * (1 - self.mask_prob)))
                perm = torch.randperm(l)
                unmasked_idxs = perm[:n_to_keep]
                masked_idxs = perm[n_to_keep:]
                inv_perm = torch.argsort(perm)
            seq = seq[:, perm[:n_to_keep], :]

        # Apply encoder
        for block in self.encoder:
            seq = block(seq)

        if not training:
            return seq

        # Add back masked patches, positional embeddings, unshuffle
        mask_chunk = repeat(self.mask_token, "d -> b l d", b=seq.shape[0], l=l - n_to_keep)
        pos_emb_chunk = self.pos_embs[perm[n_to_keep:], :].unsqueeze(0)
        mask_chunk = mask_chunk + pos_emb_chunk
        seq = torch.cat([seq, mask_chunk], dim=1)
        seq = seq[:, inv_perm, :]

        # Apply decoder
        for block in self.decoder:
            seq = block(seq)

        # Output
        output =  self.output_head(seq) # batch, n_patches, patch_size
        return masked_idxs, output

In [None]:
vit = MAEVisionTransformer(64, 16, 5, 0.75, 16, 8, 12, 768, 64, dropout=0.1)

In [None]:
base_path = Path("/content/")
train_path = base_path / "train"
all_fragments = sorted([f.name for f in train_path.iterdir()])
print("All fragments:", all_fragments)
train_fragments = [train_path / fragment_name for fragment_name in all_fragments]
train_dset = Patch3DDataset(fragments=train_fragments, patch_shape=(5, 64, 64))
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=128, shuffle=False, pin_memory=True)

In [None]:
wandb.init(
    project="scroll-transformer",
    config={"lr": 1e-4, "wd": 1e-3}
)

CKPT_SAVE_DIR = "/content/drive/MyDrive/scroll_vit_checkpoints/"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
vit.to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.AdamW(vit.parameters(), lr=1e-4, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

max_steps = 25000
steps = 0
running_loss = 0.0
print_every = 50
while True:
    for it, batch in enumerate(train_loader):
        steps += 1
        X = batch[0].to(device)

        optimizer.zero_grad()
        masked_idxs, out_patches = vit(X)
        preds = out_patches[:, masked_idxs, :]
        original_patches = rearrange(X, "b c (h1 h2) (w1 w2) -> b (h1 w1) (c h2 w2)", h2=vit.patch_size, w2=vit.patch_size)
        targets = original_patches[:, masked_idxs, :]
        loss = criterion(preds, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        wandb.log({"train_loss": loss.item(), "lr": scheduler.get_last_lr()[0]})
        if steps % print_every == 0:
            print(f"STEP {steps} | LOSS: {running_loss / print_every:.3f} | LR: {scheduler.get_last_lr()[0]:.6f}")
            running_loss = 0.0
            scheduler.step()
        if steps % 5000 == 0:
            torch.save(vit, CKPT_SAVE_DIR + "vit-" + str(steps) + ".ckpt")
        if steps >= max_steps:
            break
    if steps >= max_steps:
        break

torch.save(vit, CKPT_SAVE_DIR + "vit-final.ckpt")