In [26]:
import torch
import torch.nn as nn
import numpy as np
from datasets import load_dataset
from transformers import CLIPVisionModel, CLIPProcessor, TrainingArguments, Trainer
from sklearn.metrics import classification_report
from PIL import Image
import evaluate
import shutil
import random
from PIL import ImageOps, ImageEnhance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

total, used, free = shutil.disk_usage(".")
print(f"Disk space: {free // (2**20)} MB free")
if free < 5 * 2**30:
    raise OSError("Not enough disk space. Please free up space before running this notebook.")

# Load dataset
data_dir = "fabric_dataset2"
dataset = load_dataset("imagefolder", data_dir=data_dir)
split = dataset["train"].train_test_split(test_size=0.2, seed=42)

# Label mapping
labels = split["train"].features["label"].names
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

# CLIP processor
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Augmentation
def apply_augmentation(image):
    if random.random() < 0.5:
        image = ImageOps.mirror(image)
    if random.random() < 0.3:
        image = image.rotate(random.choice([-15, -10, -5, 5, 10, 15]))
    if random.random() < 0.3:
        image = ImageEnhance.Brightness(image).enhance(random.uniform(0.8, 1.2))
    if random.random() < 0.3:
        image = ImageEnhance.Contrast(image).enhance(random.uniform(0.8, 1.2))
    if random.random() < 0.3:
        image = ImageOps.autocontrast(image)
    if random.random() < 0.3:
        image = ImageEnhance.Color(image).enhance(random.uniform(0.8, 1.2))
    return image

# Transform
def transform_fn(example):
    image = None
    try:
        image = example["image"]
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.array(image))
        image = image.convert("RGB")
        image = apply_augmentation(image)
        inputs = processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)
        if pixel_values.shape != (3, 224, 224):
            raise ValueError("Invalid pixel_values shape")
        return {"pixel_values": pixel_values, "label": int(example["label"])}
    except Exception as e:
        print(f"Skipping image due to error: {e}")
        return None
    finally:
        if image is not None and hasattr(image, "close"):
            try:
                image.close()
            except Exception:
                pass

split = split.map(transform_fn, remove_columns=["image"], num_proc=1)
split = split.filter(lambda x: x is not None and "pixel_values" in x and "label" in x)
split["train"] = split["train"].map(lambda x: {"label": int(x["label"])})
split["test"] = split["test"].map(lambda x: {"label": int(x["label"])})

split["train"].set_format(type="torch", columns=["pixel_values", "label"])
split["test"].set_format(type="torch", columns=["pixel_values", "label"])

# --- CutMix and Mixup helpers ---
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix_data(x, y, alpha=1.0):
    ''' x: images, y: labels '''
    lam = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(x.size()[0])
    y_a = y
    y_b = y[rand_index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x[:, :, bby1:bby2, bbx1:bbx2] = x[rand_index, :, bby1:bby2, bbx1:bbx2]
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    return x, y_a, y_b, lam

def mixup_data(x, y, alpha=1.0):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size()[0]
    index = torch.randperm(batch_size)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# Model
class CLIPViTClassifier(nn.Module):
    def __init__(self, num_classes, use_cutmix=False, use_mixup=False, alpha=1.0):
        super().__init__()
        self.clip_vit = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        for param in self.clip_vit.parameters():
            param.requires_grad = True
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.clip_vit.config.hidden_size, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()
        self.use_cutmix = use_cutmix
        self.use_mixup = use_mixup
        self.alpha = alpha

    def forward(self, pixel_values, labels=None):
        # Only apply CutMix/Mixup during training and if labels are provided
        if self.training and labels is not None:
            if self.use_cutmix:
                pixel_values, targets1, targets2, lam = cutmix_data(pixel_values, labels, self.alpha)
            elif self.use_mixup:
                pixel_values, targets1, targets2, lam = mixup_data(pixel_values, labels, self.alpha)
            else:
                targets1, targets2, lam = labels, labels, 1.0
        else:
            targets1, targets2, lam = labels, labels, 1.0

        outputs = self.clip_vit(pixel_values=pixel_values)
        pooled_output = outputs.pooler_output
        logits = self.classifier(self.dropout(pooled_output))
        loss = None
        if labels is not None:
            if self.training and (self.use_cutmix or self.use_mixup):
                loss = lam * self.loss_fn(logits, targets1) + (1 - lam) * self.loss_fn(logits, targets2)
            else:
                loss = self.loss_fn(logits, labels)
        return (loss, logits)

# --- Choose one: use_cutmix=True or use_mixup=True ---
model = CLIPViTClassifier(num_classes=len(labels), use_cutmix=True, alpha=1.0).to(device)
# model = CLIPViTClassifier(num_classes=len(labels), use_mixup=True, alpha=1.0).to(device)

# Metrics
accuracy = evaluate.load("accuracy")
def compute_metrics(p):
    preds = np.argmax(p.predictions, axis=1)
    return accuracy.compute(predictions=preds, references=p.label_ids)

# Training arguments
args = TrainingArguments(
    output_dir="./clip-vit-fabric2",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=5,
    logging_dir="./logs",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to=[], 
    max_grad_norm=0.5,
    fp16=True if torch.cuda.is_available() else False,
    remove_unused_columns=False,
    learning_rate=5e-6,
    weight_decay=0.01,
    # optimizer_class=torch.optim.AdamW,  # Removed invalid argument
    # optimizers=(None, None),  # Use default optimizer
)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

# Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=split["train"],
    eval_dataset=split["test"],
    compute_metrics=compute_metrics,
)

# Train
trainer.train(resume_from_checkpoint=False)

# Evaluate
preds = trainer.predict(split["test"])
y_true = preds.label_ids
y_pred = np.argmax(preds.predictions, axis=1)
print("\n Classification Report:\n", classification_report(y_true, y_pred, target_names=labels))

Disk space: 162148 MB free


Resolving data files:   0%|          | 0/6005 [00:00<?, ?it/s]



  0%|          | 0/3000 [00:00<?, ?it/s]

{'loss': 2.4751, 'grad_norm': 73.02481079101562, 'learning_rate': 4.171666666666667e-06, 'epoch': 0.83}


  0%|          | 0/150 [00:00<?, ?it/s]

{'eval_loss': 1.2580868005752563, 'eval_accuracy': 0.6933333333333334, 'eval_runtime': 39.442, 'eval_samples_per_second': 30.424, 'eval_steps_per_second': 3.803, 'epoch': 1.0}
{'loss': 1.6705, 'grad_norm': 73.44835662841797, 'learning_rate': 3.3383333333333333e-06, 'epoch': 1.67}


  0%|          | 0/150 [00:00<?, ?it/s]

{'eval_loss': 0.999728798866272, 'eval_accuracy': 0.735, 'eval_runtime': 71.233, 'eval_samples_per_second': 16.846, 'eval_steps_per_second': 2.106, 'epoch': 2.0}
{'loss': 1.4107, 'grad_norm': 141.67831420898438, 'learning_rate': 2.505e-06, 'epoch': 2.5}


  0%|          | 0/150 [00:00<?, ?it/s]

{'eval_loss': 0.946610689163208, 'eval_accuracy': 0.745, 'eval_runtime': 37.342, 'eval_samples_per_second': 32.135, 'eval_steps_per_second': 4.017, 'epoch': 3.0}
{'loss': 1.2268, 'grad_norm': 82.1813735961914, 'learning_rate': 1.6716666666666666e-06, 'epoch': 3.33}


  0%|          | 0/150 [00:00<?, ?it/s]

{'eval_loss': 0.8811412453651428, 'eval_accuracy': 0.7625, 'eval_runtime': 36.7005, 'eval_samples_per_second': 32.697, 'eval_steps_per_second': 4.087, 'epoch': 4.0}
{'loss': 1.1091, 'grad_norm': 80.7503433227539, 'learning_rate': 8.400000000000001e-07, 'epoch': 4.17}
{'loss': 1.0241, 'grad_norm': 125.87419891357422, 'learning_rate': 6.666666666666667e-09, 'epoch': 5.0}


  0%|          | 0/150 [00:00<?, ?it/s]

{'eval_loss': 0.8599340915679932, 'eval_accuracy': 0.7691666666666667, 'eval_runtime': 46.0972, 'eval_samples_per_second': 26.032, 'eval_steps_per_second': 3.254, 'epoch': 5.0}
{'train_runtime': 2164.284, 'train_samples_per_second': 11.084, 'train_steps_per_second': 1.386, 'train_loss': 1.4860526326497396, 'epoch': 5.0}


  0%|          | 0/150 [00:00<?, ?it/s]


 Classification Report:
                               precision    recall  f1-score   support

   abstract_geometric_fabric       0.70      0.67      0.69        49
               argyle_fabric       1.00      0.97      0.99        37
              checked_fabric       0.82      0.84      0.83        58
              chevron_fabric       0.88      0.83      0.85        53
diagonal_grid_fabric_pattern       0.86      0.80      0.83        55
               floral_fabric       0.71      0.62      0.66        39
               fringe_fabric       0.91      0.76      0.83        38
         glitch_print_fabric       0.54      0.54      0.54        35
              glitter_fabric       0.54      0.48      0.51        29
             gradient_fabric       0.58      0.58      0.58        31
       graffiti_print_fabric       0.82      0.82      0.82        34
          holographic_fabric       0.65      0.89      0.75        27
          houndstooth_fabric       0.71      0.80      0.75    

In [27]:
torch.save(model.state_dict(), "clip-vit-fabric2.pth")

In [5]:
import torch
import torch.nn as nn
from transformers import CLIPVisionModel

class CLIPViTClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.clip_vit = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.clip_vit.config.hidden_size, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, pixel_values, labels=None):
        outputs = self.clip_vit(pixel_values=pixel_values)
        pooled_output = outputs.pooler_output
        logits = self.classifier(self.dropout(pooled_output))
        loss = None
        if labels is not None:
            loss = self.loss_fn(logits, labels)
        return (loss, logits)

from safetensors.torch import load_file

model = CLIPViTClassifier(num_classes=28)  # match the checkpoint's number of classes
state_dict = load_file("clip-vit-fabric2/checkpoint-3000/model.safetensors", device="cpu")
model.load_state_dict(state_dict)
model.eval()
# Save the model in a format compatible with the Hugging Face Transformers library
from safetensors.torch import save_file
save_file(model.state_dict(), "clip-vit-fabric2-hf")
