In [None]:
# Mount Google Drive so later cells can read files stored under /content/drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!pip install torch torchvision pytorch-lightning

In [None]:
!unzip /content/drive/MyDrive/simclr.zip

In [None]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, Dataset
from PIL import Image

In [None]:
class SimCLR(pl.LightningModule):
    """Contrastive pre-training module: backbone encoder plus MLP projection head.

    The backbone's final ``fc`` layer is replaced by a two-layer MLP that maps
    the backbone features to an ``out_dim``-dimensional embedding. Training
    pulls the embeddings of two augmented views of the same image together and
    pushes them away from the other images in the batch.

    Args:
        base_encoder: Callable that accepts ``num_classes=...`` and returns a
            model exposing an ``fc`` attribute with ``in_features``
            (e.g. ``torchvision.models.resnet50``).
        out_dim: Size of the projected embedding.
        temperature: Positive scale applied to the cosine-similarity logits.
            Without it the logits are confined to [-1, 1] and the softmax is
            nearly uniform, which gives a very weak training signal.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, base_encoder, out_dim, temperature=0.1):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.temperature = temperature

        self.encoder = base_encoder(num_classes=out_dim)
        dim_mlp = self.encoder.fc.in_features
        # Swap the classification layer for the projection head.
        self.encoder.fc = nn.Sequential(
            nn.Linear(dim_mlp, dim_mlp),
            nn.ReLU(),
            nn.Linear(dim_mlp, out_dim)
        )
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the projected embedding of ``x``."""
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        """Compute the contrastive loss for one batch of ``((view1, view2), _)``."""
        (x1, x2), _ = batch
        z1, z2 = self.encoder(x1), self.encoder(x2)
        logits, labels = self.info_nce_loss(z1, z2)
        loss = self.criterion(logits, labels)
        # Record the loss so training progress is visible in the logger/progress bar.
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def info_nce_loss(self, z1, z2):
        """Build InfoNCE logits and targets for two batches of embeddings.

        Row ``i`` of ``logits`` holds the temperature-scaled cosine similarity
        between ``z1[i]`` and every row of ``z2``; the matching target is ``i``
        (the other view of the same image).

        Returns:
            Tuple ``(logits, labels)`` suitable for ``nn.CrossEntropyLoss``.
        """
        z1 = nn.functional.normalize(z1, dim=1)
        z2 = nn.functional.normalize(z2, dim=1)
        logits = torch.mm(z1, z2.t()) / self.temperature
        # Create the targets directly on the logits' device.
        labels = torch.arange(logits.size(0), device=logits.device)
        return logits, labels

    def configure_optimizers(self):
        """Optimise all parameters with Adam (lr=1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [None]:
class CustomDataset(Dataset):
    """Unlabelled image folder that yields two transformed views per image.

    Args:
        root_dir: Directory that directly contains the image files.
        transform: Optional callable applied independently to the image to
            produce each of the two views. When ``None`` both views are the
            untransformed RGB image.

    Raises:
        ValueError: If ``root_dir`` contains no ``.jpg``/``.jpeg``/``.png`` files.
    """

    # Matched case-insensitively and with the leading dot, so "IMG.JPG" is
    # accepted while a file merely ending in the letters "jpg" is not.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort so index -> image is reproducible.
        self.image_paths = sorted(
            path
            for path in (os.path.join(root_dir, fname) for fname in os.listdir(root_dir))
            if path.lower().endswith(self.IMAGE_EXTENSIONS) and os.path.isfile(path)
        )

        if len(self.image_paths) == 0:
            raise ValueError(f"No images found in {root_dir}. Please check the directory and file extensions.")

        print(f"Found {len(self.image_paths)} images in {root_dir}")

    def __len__(self):
        """Number of images found in ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``((view1, view2), 0)`` for the image at ``idx``.

        The trailing ``0`` is a placeholder label.
        """
        img_path = self.image_paths[idx]
        # convert() returns a fully loaded copy, so the file can be closed here.
        with Image.open(img_path) as handle:
            img = handle.convert('RGB')
        if self.transform:
            img1 = self.transform(img)
            img2 = self.transform(img)
        else:
            # Previously img1/img2 were undefined on this path (UnboundLocalError).
            img1 = img2 = img
        return (img1, img2), 0

In [None]:
# Seed Python, NumPy and torch (and the DataLoader workers) so the random
# augmentations and weight initialisation are repeatable between runs.
pl.seed_everything(42, workers=True)

# Augmentation pipeline used to create the two views of each image.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
])

dataset = CustomDataset(root_dir='/content/simclr/train', transform=transform)
# Do not start more loader processes than the machine has CPUs.
dataloader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=min(8, os.cpu_count() or 1),
)

simclr = SimCLR(base_encoder=torchvision.models.resnet50, out_dim=128)
# Keep the logging interval within the number of batches per epoch; the
# default interval is larger than an epoch here, so Lightning warns about it.
trainer = Trainer(
    max_epochs=100,
    log_every_n_steps=max(1, min(50, len(dataloader))),
)

trainer.fit(simclr, dataloader)

# Saves the whole LightningModule state: keys are prefixed with "encoder."
# and include the projection head stored under "encoder.fc".
torch.save(simclr.state_dict(), 'simclr_pretrained.pth')

Found 1300 images in /content/simclr/train


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type             | Params
-----------------------------------------------
0 | encoder   | ResNet           | 28.0 M
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
28.0 M    Trainable params
0         Non-trainable params
28.0 M    Total params
111.867   Total estimated model params size (MB)
  self.pid = os.fork()
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every

Training: |          | 0/? [00:00<?, ?it/s]

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return F.conv2d(input, weight, bias, self.stride,
  self.pid = os.fork()
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

class TransferLearningModel(pl.LightningModule):
    """Linear classifier trained on top of a frozen feature extractor.

    Args:
        feature_extractor: Module mapping an image batch to features; its
            output is flattened to ``(batch, feature_dim)``. It is frozen in
            place (``requires_grad=False`` and kept in eval mode).
        num_classes: Number of output classes.
        feature_dim: Flattened feature size. When ``None`` it is measured by
            passing one zero image of ``input_shape`` through the extractor.
        input_shape: ``(channels, height, width)`` of the probe image used
            when ``feature_dim`` is not given.
    """

    def __init__(self, feature_extractor, num_classes, feature_dim=None, input_shape=(3, 32, 32)):
        super().__init__()
        self.feature_extractor = feature_extractor
        # forward() runs the extractor under no_grad, so it is never trained:
        # freeze its parameters and stop BatchNorm from updating its statistics.
        self.feature_extractor.requires_grad_(False)
        self.feature_extractor.eval()

        if feature_dim is None:
            # The previous `feature_extractor[-1].in_features` lookup fails when
            # the last layer has no such attribute (e.g. a pooling layer), so
            # measure the real output size instead.
            feature_dim = self._infer_feature_dim(input_shape)

        self.classifier = nn.Linear(feature_dim, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def _infer_feature_dim(self, input_shape):
        """Return the flattened output size of the extractor for one probe image."""
        reference = next(self.feature_extractor.parameters(), None)
        probe = torch.zeros(1, *input_shape)
        if reference is not None:
            probe = probe.to(device=reference.device, dtype=reference.dtype)
        with torch.no_grad():
            return self.feature_extractor(probe).flatten(1).shape[1]

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen extractor in eval mode."""
        super().train(mode)
        self.feature_extractor.eval()
        return self

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        with torch.no_grad():
            features = self.feature_extractor(x).flatten(1)
        logits = self.classifier(features)
        return logits

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss for one ``(images, labels)`` batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Optimise only the classifier; the feature extractor is frozen."""
        return torch.optim.Adam(self.classifier.parameters(), lr=1e-3)

# Define the data transforms
transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Load the training data
train_dataset = ImageFolder(root='/content/deep_class/train', transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)

# Load the pre-trained SimCLR weights
checkpoint = torch.load('simclr_pretrained.pth', map_location='cpu')

# The checkpoint was saved from the SimCLR LightningModule, so its keys look
# like "encoder.conv1.weight", while a bare ResNet-50 expects "conv1.weight".
# Loading it unchanged with strict=False matches nothing and silently leaves
# the backbone randomly initialised. Strip the prefix and drop the projection
# head ("fc.*"), which has a different structure from the ResNet classifier.
backbone_state = {}
for key, value in checkpoint.items():
    name = key[len('encoder.'):] if key.startswith('encoder.') else key
    if name.startswith('fc.'):
        continue
    backbone_state[name] = value

resnet50_base = torchvision.models.resnet50()

# Load the SimCLR weights into the ResNet-50 backbone and verify that every
# backbone tensor was actually restored (only the fc layer may be missing).
load_result = resnet50_base.load_state_dict(backbone_state, strict=False)
missing_backbone_keys = [k for k in load_result.missing_keys if not k.startswith('fc.')]
if missing_backbone_keys or load_result.unexpected_keys:
    raise RuntimeError(
        "SimCLR checkpoint does not match the ResNet-50 backbone: "
        f"missing={missing_backbone_keys}, unexpected={list(load_result.unexpected_keys)}"
    )

# Every ResNet-50 layer except the final fc layer acts as the feature extractor
simclr_encoder = nn.Sequential(*list(resnet50_base.children())[:-1])

# Create the transfer-learning model
num_classes = len(train_dataset.classes)
model = TransferLearningModel(feature_extractor=simclr_encoder, num_classes=num_classes)

# Train with PyTorch Lightning
trainer = Trainer(max_epochs=50)
trainer.fit(model, train_dataloader)