In [2]:
 !pip install torch torchvision datasets --upgrade
 !pip install transformers timm tqdm
 !pip install wandb   # Optional for tracking

Collecting torch
  Downloading torch-2.7.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting torchvision
  Downloading torchvision-0.22.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch)
  Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none

In [3]:
!pip install -q scikit-learn matplotlib

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import os

WORK_DIR = "/content/drive/MyDrive/rvlcdip_classifier"
os.makedirs(WORK_DIR, exist_ok=True)

In [7]:
from datasets import load_dataset

dataset = load_dataset("aharley/rvl_cdip", cache_dir="/content/drive/MyDrive/rvl_cdip_cache", trust_remote_code=True)


rvl-cdip.tar.gz:  17%|#6        | 7.91G/46.7G [00:00<?, ?B/s]

train.txt:   0%|          | 0.00/13.7M [00:00<?, ?B/s]

test.txt:   0%|          | 0.00/1.72M [00:00<?, ?B/s]

val.txt:   0%|          | 0.00/1.72M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/320000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/40000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/40000 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/77 [00:00<?, ?it/s]

In [8]:
print(dataset)
print(dataset['train'][0])

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 320000
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 40000
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 40000
    })
})
{'image': <PIL.Image.Image image mode=L size=762x1000 at 0x7FD27E35F110>, 'label': 11}


In [9]:
labels = dataset['train'].features['label'].names
id2label = {i: l for i, l in enumerate(labels)}
label2id = {l: i for i, l in enumerate(labels)}

print(id2label)


{0: 'letter', 1: 'form', 2: 'email', 3: 'handwritten', 4: 'advertisement', 5: 'scientific report', 6: 'scientific publication', 7: 'specification', 8: 'file folder', 9: 'news article', 10: 'budget', 11: 'invoice', 12: 'presentation', 13: 'questionnaire', 14: 'resume', 15: 'memo'}


In [10]:
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])


In [11]:
from torch.utils.data import Dataset

class RVLCDIPDataset(Dataset):
    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']
        label = item['label']

        if self.transform:
            image = self.transform(image)

        return image, label

    def __len__(self):
        return len(self.dataset)


In [12]:
train_ds = RVLCDIPDataset(dataset['train'], transform=train_transform)
val_ds = RVLCDIPDataset(dataset['validation'], transform=val_transform)
test_ds = RVLCDIPDataset(dataset['test'], transform=val_transform)


In [13]:
from torch.utils.data import DataLoader

BATCH_SIZE = 64

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)


In [11]:
!mkdir -p /content/drive/MyDrive/rvl_cdip_cache


In [12]:
!cp -r /root/.cache/huggingface/datasets /content/drive/MyDrive/rvl_cdip_cache/

In [13]:
!ls -lh /content/drive/MyDrive/rvl_cdip_cache/datasets


total 8.0K
drwx------ 3 root root 4.0K Jun  9 02:05 aharley___rvl_cdip
drwx------ 2 root root 4.0K Jun  9 02:05 downloads
-rw------- 1 root root    0 Jun  9 02:05 _root_.cache_huggingface_datasets_aharley___rvl_cdip_default_1.0.0_b9e57261da1599f0f1a4a03abe42d47bfa600a4fd3d3297cd0d0d45309085b23.lock


In [14]:
!du -sh /content/drive/MyDrive/rvl_cdip_cache/datasets


46G	/content/drive/MyDrive/rvl_cdip_cache/datasets


In [1]:
from google.colab import drive
drive.mount('/content/drive')

!rm -rf /root/.cache/huggingface/datasets
!mkdir -p /root/.cache/huggingface
!cp -r /content/drive/MyDrive/rvl_cdip_cache/datasets /root/.cache/huggingface/


Mounted at /content/drive


In [14]:
import torch
import torch.nn as nn
from torchvision import models

NUM_CLASSES = 16  # For RVL-CDIP

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load EfficientNet-B0
model = models.efficientnet_b0(pretrained=True)

# Replace the classifier
model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)

model = model.to(device)




Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


100%|██████████| 20.5M/20.5M [00:00<00:00, 100MB/s]


In [15]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [16]:
EPOCHS = 5  # Start small

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {running_loss:.4f}, Accuracy: {100 * correct / total:.2f}%")


Epoch 1/5, Loss: 3253.0384, Accuracy: 80.63%
Epoch 2/5, Loss: 1948.8120, Accuracy: 88.29%
Epoch 3/5, Loss: 1571.9310, Accuracy: 90.45%
Epoch 4/5, Loss: 1318.3369, Accuracy: 91.86%
Epoch 5/5, Loss: 1115.5758, Accuracy: 93.07%


In [17]:
model.eval()
val_correct = 0
val_total = 0

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        val_total += labels.size(0)
        val_correct += (predicted == labels).sum().item()

print(f"Validation Accuracy: {100 * val_correct / val_total:.2f}%")

Validation Accuracy: 91.06%


In [19]:
def evaluate(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            try:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            except Exception as e:
                print(f"Skipped batch due to error: {e}")
    print(f"Test Accuracy: {100 * correct / total:.2f}%")

evaluate(model, test_loader)




UnidentifiedImageError: Caught UnidentifiedImageError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "<ipython-input-11-907c0577df25>", line 9, in __getitem__
    item = self.dataset[idx]
           ~~~~~~~~~~~~^^^^^
  File "/usr/local/lib/python3.11/dist-packages/datasets/arrow_dataset.py", line 2777, in __getitem__
    return self._getitem(key)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/datasets/arrow_dataset.py", line 2762, in _getitem
    formatted_output = format_table(
                       ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/datasets/formatting/formatting.py", line 653, in format_table
    return formatter(pa_table, query_type=query_type)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/datasets/formatting/formatting.py", line 406, in __call__
    return self.format_row(pa_table)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/datasets/formatting/formatting.py", line 455, in format_row
    row = self.python_features_decoder.decode_row(row)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/datasets/formatting/formatting.py", line 223, in decode_row
    return self.features.decode_example(row, token_per_repo_id=self.token_per_repo_id) if self.features else row
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/datasets/features/features.py", line 2100, in decode_example
    return {
           ^
  File "/usr/local/lib/python3.11/dist-packages/datasets/features/features.py", line 2101, in <dictcomp>
    column_name: decode_nested_example(feature, value, token_per_repo_id=token_per_repo_id)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/datasets/features/features.py", line 1414, in decode_nested_example
    return schema.decode_example(obj, token_per_repo_id=token_per_repo_id) if obj is not None else None
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/datasets/features/image.py", line 186, in decode_example
    image = PIL.Image.open(BytesIO(bytes_))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/PIL/Image.py", line 3572, in open
    raise UnidentifiedImageError(msg)
PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x7fd1d6a9ade0>


In [20]:
# Save the model state dict after training
torch.save(model.state_dict(), "dark-cat.pth")

In [22]:
# Example values (replace these with your actual variables)
epoch = 5
train_loss = 1115.5758
train_accuracy = 93.07
val_accuracy = 91.07

# Save checkpoint
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_loss': train_loss,
    'train_accuracy': train_accuracy,
    'val_accuracy': val_accuracy,
}, "checkpoint.pth")


In [23]:
from google.colab import files
files.download("checkpoint.pth")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>