In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import os

In [2]:
# Per-channel mean/std used to normalise the image tensors.
# NOTE(review): these look like the statistics commonly used with torchvision's
# ImageNet-pretrained models - confirm they match the loaded checkpoint.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Steps run on every image each time it is read from the folder, in order:
# random colour jitter, resize to 224x224, conversion to a tensor, normalisation.
# NOTE(review): ColorJitter is random and this single pipeline is attached to the
# one dataset object that later cells split into train and test subsets, so
# evaluation images are jittered as well.
image_pipeline = transforms.Compose([
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
])

# ImageFolder reads the images under the 'dataset' directory.
dataset = datasets.ImageFolder(root='dataset', transform=image_pipeline)

- If scikit-learn is not installed yet, run `%pip install scikit-learn` in a code cell before continuing.

In [3]:
from sklearn.model_selection import KFold
import torch.utils.data

k = 12
FOLD_SEED = 42  # fixed seed so every run assigns samples to the same folds
kf = KFold(n_splits=k, shuffle=True, random_state=FOLD_SEED)

# One (train, test) DataLoader pair per fold, stored at matching positions.
train_loaders = []
test_loaders = []

# KFold only needs to know how many samples there are, so split a plain list of
# indices instead of handing the image dataset object to scikit-learn.
sample_indices = list(range(len(dataset)))

for train_indices, test_indices in kf.split(sample_indices):
    # Samplers restrict each loader to its fold's indices while both loaders
    # keep reading from the same underlying dataset.
    train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
    test_sampler = torch.utils.data.SubsetRandomSampler(test_indices)

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        sampler=train_sampler,
        num_workers=0
    )

    test_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        sampler=test_sampler,
        num_workers=0
    )

    train_loaders.append(train_loader)
    test_loaders.append(test_loader)


In [4]:
# Hold-out split: reserve a fixed fraction of the dataset for evaluation.
test_percent = 0.2
SPLIT_SEED = 42  # fixed seed so re-running this cell yields the same train/test membership
num_test = int(test_percent * len(dataset))

# Without an explicit generator the split is re-drawn on every execution, so a
# checkpoint saved earlier could later be evaluated on images it was trained on.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [len(dataset) - num_test, num_test],
    generator=split_generator
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

# Evaluation does not benefit from shuffling; a fixed order keeps it repeatable.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0
)

In [None]:
from rdnet18_ca import rdnet18_ca

model_path = "resnet18-5c106cde.pth"
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source.
# map_location="cpu" lets the checkpoint load on machines without a GPU; the
# model is moved to its target device in a later cell.
pretrained_dict = torch.load(model_path, map_location="cpu")

model = rdnet18_ca()

model_dict = model.state_dict()

# Keep only checkpoint entries that exist in this model AND have the same shape.
# Filtering on the key alone makes load_state_dict raise as soon as a same-named
# layer has a different size.
pretrained_dict = {
    name: tensor
    for name, tensor in pretrained_dict.items()
    if name in model_dict and tensor.shape == model_dict[name].shape
}
# Report how much was actually transferred so a partial load is not silent.
print("Loaded {}/{} tensors from {}".format(len(pretrained_dict), len(model_dict), model_path))

# Entries not covered by the checkpoint keep the values the model was created with.
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)


In [6]:
# Replace the model's `fc` attribute with a linear layer mapping 512 features to
# 2 outputs.
model.fc = nn.Linear(in_features=512, out_features=2)

# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = model.to(device)

In [7]:
def accuracy(predictions, labels):
    """Count how many rows of ``predictions`` pick the labelled class.

    Args:
        predictions: tensor of scores; the predicted class of each row is the
            index of its largest value along dimension 1.
        labels: tensor of target class indices, reshaped to match the
            predicted-class tensor before comparison.

    Returns:
        A tuple ``(rights, total)`` where ``rights`` is a tensor holding the
        number of correct predictions and ``total`` is ``len(labels)``.
    """
    predicted_classes = predictions.detach().argmax(dim=1)
    targets = labels.detach().view_as(predicted_classes)
    num_correct = (predicted_classes == targets).sum()
    return num_correct, len(labels)

In [None]:
import matplotlib.pyplot as plt
num_epochs = 30
BEST_MODEL_PATH = 'best_model_classification.pth'
LOG_INTERVAL = 100  # training batches between validation / logging points
criterion = nn.CrossEntropyLoss()
best_r = 0.0  # best end-of-epoch validation accuracy (percent) seen so far
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Histories sampled every LOG_INTERVAL batches; plotted after training.
train_losses = []
val_losses = []
train_accs = []
val_accs = []


def evaluate(model, loader, criterion, device):
    """Run ``model`` over every batch of ``loader`` without tracking gradients.

    Args:
        model: network to evaluate. It is put in eval mode for the pass and
            returned to its previous train/eval mode afterwards.
        loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss callable taking ``(outputs, labels)``; assumed to
            return the mean loss over the batch.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy_percent)`` computed over all samples seen.
        An empty loader yields ``(nan, 0.0)``.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    # no_grad avoids building autograd graphs that evaluation never uses.
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            batch_outputs = model(batch_images)
            rights, count = accuracy(batch_outputs, batch_labels)
            # Weight each batch by its size so a short final batch is not over-counted.
            loss_sum += criterion(batch_outputs, batch_labels).item() * count
            correct += int(rights)
            seen += count
    model.train(was_training)
    if seen == 0:
        return float('nan'), 0.0
    return loss_sum / seen, 100. * correct / seen


for epoch in range(num_epochs):
    model.train()
    # Running totals for this epoch's training accuracy. `train_correct` becomes
    # a tensor after the first batch and is only converted at logging points.
    train_correct = 0
    train_seen = 0

    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        rights, count = accuracy(outputs, labels)
        train_correct = train_correct + rights
        train_seen += count

        if batch_idx % LOG_INTERVAL == 0:
            # Validation metrics cover the whole test loader, not just its last batch.
            val_loss, val_acc = evaluate(model, test_loader, criterion, device)
            train_acc = 100. * float(train_correct) / train_seen

            print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(
                epoch, batch_idx * train_loader.batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item(),
                train_acc, val_acc))

            train_losses.append(loss.item())
            val_losses.append(val_loss)
            train_accs.append(train_acc)
            val_accs.append(val_acc)

    # Score the weights as they are at the end of the epoch, so the accuracy used
    # for the checkpoint decision belongs to the weights that get saved.
    epoch_val_loss, epoch_val_acc = evaluate(model, test_loader, criterion, device)
    if epoch_val_acc > best_r:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_r = epoch_val_acc
        print("Saved best model with validation accuracy: {:.2f}%".format(epoch_val_acc))

plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Logging step')
plt.ylabel('Loss')
plt.legend()
plt.show()

plt.plot(train_accs, label='Train Accuracy')
plt.plot(val_accs, label='Validation Accuracy')
plt.xlabel('Logging step')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.show()
