In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torchvision.models.resnet import ResNet18_Weights
from torch.utils.data import Dataset, DataLoader,Subset,ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import numpy as np
import os
import random
from PIL import Image

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 数据预处理和加载
transform = transforms.Compose(transforms=[
    transforms.Resize((64, 64)),
    transforms.Resize((224, 224)),
    transforms.transforms.RandomHorizontalFlip(p=0.2), 
    transforms.RandomRotation(degrees=20),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
test_transform = transforms.Compose(transforms=[
    transforms.Resize((64, 64)),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 定义 Flower102 合成数据集
class Flower102SyntheticDataset(Dataset):
    def __init__(self, npz_file, transform=transform):
        # 加载 npz 文件
        data = np.load(npz_file, allow_pickle=True)
        self.images = data['arr_0']  # Flower 图片
        self.labels = data['arr_1']  # 对应的标签
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.fromarray(self.images[idx].astype('uint8'))  # 转换为 PIL 图像
        label = self.labels[idx]  # 取对应的标签
        if self.transform:
            image = self.transform(image)
        return image, label

class DuplicatedDataset(Dataset):
    def __init__(self, original_dataset, num_copies=2):
        self.original_dataset = original_dataset
        self.num_copies = num_copies

    def __len__(self):
        return len(self.original_dataset) * self.num_copies

    def __getitem__(self, idx):
        return self.original_dataset[idx % len(self.original_dataset)]


npz_file="./synthetic_data/UNet_flowers-250-sampling_steps-50000_images-class_condn_True.npz"

def load_sampled_synthetic_data(npz_file, original_size=1020, batch_size=64,scale=2):
    dataset = Flower102SyntheticDataset(npz_file)
    sample_size = original_size*scale
    # 从合成数据集中随机采样
    sampled_indices = random.sample(range(len(dataset)), sample_size)
    sampled_dataset = Subset(dataset, sampled_indices)

    # 使用 DataLoader 加载数据
    return sampled_dataset

# trainset
scale_size= 12
trainset = datasets.Flowers102(root='./data', split='train', download=True, transform=transform)

syn_trainset = load_sampled_synthetic_data(npz_file, scale=scale_size-1)
combined_dataset = ConcatDataset([trainset, syn_trainset])
print(f"syn_data_size:{len(combined_dataset)}")
trainset_2 = DuplicatedDataset(trainset, num_copies=scale_size)
print(f"train_data_size:{len(trainset_2)}")

# valset and testset
valset = datasets.Flowers102(root='./data', split='val', download=True, transform=transform)
testset = datasets.Flowers102(root='./data', split='test', download=True, transform=test_transform)

# data loader
num_workers=4
batch_size=64
trainloader_syn= DataLoader(combined_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
trainloader_2 = DataLoader(trainset_2, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valloader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=num_workers)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=num_workers)


# 创建保存模型的目录
os.makedirs('checkpoints', exist_ok=True)

dataset_name = 'flowers102'
log_dir = f'/root/tf-logs/{dataset_name}'
writer = SummaryWriter(log_dir=log_dir)
# 训练函数
def train_model(model, criterion, optimizer, scheduler, trainloader, valloader, device,writer=writer, num_epochs=5):
    global best_val_loss, patience_counter
    model.train()
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        batch_iterator = tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        for inputs, labels in batch_iterator:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            batch_iterator.set_postfix(loss=loss.item())
        epoch_loss = running_loss / len(trainloader)
        train_accuracy = 100 * correct / total
        train_losses.append(epoch_loss)
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Accuracy/train', train_accuracy, epoch)
        writer.add_scalar('Learning_Rate', scheduler.get_last_lr()[0], epoch)
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} Accuracy: {train_accuracy:.2f}%")

        # 验证模型
        val_loss, val_accuracy = evaluate_model(model, criterion, valloader, device,writer, epoch)
        val_losses.append(val_loss)


    return train_losses, val_losses

# 评估函数
def evaluate_model(model, criterion, dataloader, device, writer,epoch):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc=">>> Evaluating", unit="batch"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    avg_loss = running_loss / len(dataloader)
    accuracy = 100 * correct / total
    writer.add_scalar('Loss/val' if dataloader == valloader else 'Loss/test', avg_loss, epoch)
    writer.add_scalar('Accuracy/val' if dataloader == valloader else 'Accuracy/test', accuracy, epoch)
    print(f"    Accuracy: {accuracy:.2f}% Avg_Loss: {avg_loss:.4f}")
    return avg_loss, accuracy


syn_data_size:12240
train_data_size:12240


In [2]:
num_epochs = 10

# 初始化模型 预训练模型
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 102)
model = model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

# 定义 OneCycleLR 学习率调度器

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, steps_per_epoch=len(trainloader_2), epochs=num_epochs)

# 训练和评估模型
train_model(model, criterion, optimizer, scheduler, trainloader_2, valloader, device, writer,num_epochs)

# 关闭 TensorBoard
writer.close()

print(f"--->")
evaluate_model(model, criterion, testloader, device, writer,0)

Epoch 1/10: 100%|██████████| 192/192 [00:13<00:00, 14.57batch/s, loss=4.23]


Epoch 1 Loss: 4.6001 Accuracy: 3.73%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.36batch/s]


    Accuracy: 10.29% Avg_Loss: 4.3047


Epoch 2/10: 100%|██████████| 192/192 [00:13<00:00, 14.70batch/s, loss=0.423]


Epoch 2 Loss: 1.6525 Accuracy: 61.14%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.47batch/s]


    Accuracy: 58.43% Avg_Loss: 1.8262


Epoch 3/10: 100%|██████████| 192/192 [00:13<00:00, 14.57batch/s, loss=0.00363]


Epoch 3 Loss: 0.0815 Accuracy: 97.94%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.59batch/s]


    Accuracy: 72.16% Avg_Loss: 1.2544


Epoch 4/10: 100%|██████████| 192/192 [00:13<00:00, 14.76batch/s, loss=0.00792]


Epoch 4 Loss: 0.0174 Accuracy: 99.56%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.31batch/s]


    Accuracy: 74.51% Avg_Loss: 1.1873


Epoch 5/10: 100%|██████████| 192/192 [00:13<00:00, 14.70batch/s, loss=0.000428]


Epoch 5 Loss: 0.0016 Accuracy: 100.00%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.59batch/s]


    Accuracy: 75.88% Avg_Loss: 1.2024


Epoch 6/10: 100%|██████████| 192/192 [00:13<00:00, 14.56batch/s, loss=0.00242] 


Epoch 6 Loss: 0.0010 Accuracy: 99.98%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.39batch/s]


    Accuracy: 75.78% Avg_Loss: 1.2987


Epoch 7/10: 100%|██████████| 192/192 [00:13<00:00, 14.30batch/s, loss=0.000383]


Epoch 7 Loss: 0.0005 Accuracy: 100.00%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.48batch/s]


    Accuracy: 75.88% Avg_Loss: 1.2530


Epoch 8/10: 100%|██████████| 192/192 [00:13<00:00, 14.63batch/s, loss=0.00019] 


Epoch 8 Loss: 0.0003 Accuracy: 100.00%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.56batch/s]


    Accuracy: 76.18% Avg_Loss: 1.2941


Epoch 9/10: 100%|██████████| 192/192 [00:13<00:00, 14.58batch/s, loss=0.000193]


Epoch 9 Loss: 0.0003 Accuracy: 100.00%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.49batch/s]


    Accuracy: 75.69% Avg_Loss: 1.2873


Epoch 10/10: 100%|██████████| 192/192 [00:13<00:00, 14.76batch/s, loss=0.000475]


Epoch 10 Loss: 0.0003 Accuracy: 100.00%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.45batch/s]


    Accuracy: 74.31% Avg_Loss: 1.3247
--->


>>> Evaluating: 100%|██████████| 49/49 [00:06<00:00,  7.57batch/s]


    Accuracy: 73.05% Avg_Loss: 1.5811


(1.5810608550221945, 73.05252886648235)

In [3]:

num_epochs = 10
dataset_name = 'flowers102_synthetic'
log_dir = f'/root/tf-logs/{dataset_name}'
writer = SummaryWriter(log_dir=log_dir)

def train_model(model, criterion, optimizer, scheduler, trainloader, valloader, device,writer, num_epochs=5):

    model.train()
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        batch_iterator = tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        for inputs, labels in batch_iterator:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            batch_iterator.set_postfix(loss=loss.item())
        epoch_loss = running_loss / len(trainloader)
        train_accuracy = 100 * correct / total
        train_losses.append(epoch_loss)
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Accuracy/train', train_accuracy, epoch)
        writer.add_scalar('Learning_Rate', scheduler.get_last_lr()[0], epoch)
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} Accuracy: {train_accuracy:.2f}%")

        # 验证模型
        val_loss, val_accuracy = evaluate_model(model, criterion, valloader, device,writer, epoch)
        val_losses.append(val_loss)


    return train_losses, val_losses

# 评估函数
def evaluate_model(model, criterion, dataloader, device, writer,epoch):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc=">>> Evaluating", unit="batch"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    avg_loss = running_loss / len(dataloader)
    accuracy = 100 * correct / total
    writer.add_scalar('Loss/val' if dataloader == valloader else 'Loss/test', avg_loss, epoch)
    writer.add_scalar('Accuracy/val' if dataloader == valloader else 'Accuracy/test', accuracy, epoch)
    print(f"    Accuracy: {accuracy:.2f}% Avg_Loss: {avg_loss:.4f}")
    return avg_loss, accuracy

# 初始化模型 预训练模型
model_syn = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
num_ftrs = model_syn.fc.in_features
model_syn.fc = nn.Linear(num_ftrs, 102)
model_syn = model_syn.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer_syn = optim.SGD(model_syn.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

# 定义 OneCycleLR 学习率调度器

scheduler_syn = torch.optim.lr_scheduler.OneCycleLR(optimizer_syn, max_lr=0.001, steps_per_epoch=len(trainloader_syn), epochs=num_epochs)

# 训练和评估模型
train_model(model_syn, criterion, optimizer_syn, scheduler_syn, trainloader_syn, valloader, device, writer,num_epochs)

# 关闭 TensorBoard
writer.close()

print(f"--->")
evaluate_model(model_syn, criterion, testloader, device, writer,0)

Epoch 1/10: 100%|██████████| 192/192 [00:08<00:00, 22.63batch/s, loss=4.56]


Epoch 1 Loss: 4.6923 Accuracy: 1.54%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.22batch/s]


    Accuracy: 6.08% Avg_Loss: 4.4504


Epoch 2/10: 100%|██████████| 192/192 [00:08<00:00, 23.28batch/s, loss=1.78]


Epoch 2 Loss: 3.3611 Accuracy: 23.29%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.38batch/s]


    Accuracy: 56.27% Avg_Loss: 1.6783


Epoch 3/10: 100%|██████████| 192/192 [00:08<00:00, 23.05batch/s, loss=1.48] 


Epoch 3 Loss: 1.5853 Accuracy: 59.51%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.30batch/s]


    Accuracy: 68.43% Avg_Loss: 1.1476


Epoch 4/10: 100%|██████████| 192/192 [00:08<00:00, 23.05batch/s, loss=0.995]


Epoch 4 Loss: 1.0601 Accuracy: 71.84%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.39batch/s]


    Accuracy: 78.04% Avg_Loss: 0.7656


Epoch 5/10: 100%|██████████| 192/192 [00:08<00:00, 22.61batch/s, loss=0.666]


Epoch 5 Loss: 0.7693 Accuracy: 78.43%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.24batch/s]


    Accuracy: 86.08% Avg_Loss: 0.4943


Epoch 6/10: 100%|██████████| 192/192 [00:08<00:00, 22.94batch/s, loss=0.0727]


Epoch 6 Loss: 0.5484 Accuracy: 84.63%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.42batch/s]


    Accuracy: 84.41% Avg_Loss: 0.5639


Epoch 7/10: 100%|██████████| 192/192 [00:08<00:00, 23.11batch/s, loss=0.426]


Epoch 7 Loss: 0.4119 Accuracy: 88.03%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.23batch/s]


    Accuracy: 85.10% Avg_Loss: 0.5720


Epoch 8/10: 100%|██████████| 192/192 [00:08<00:00, 22.83batch/s, loss=0.337] 


Epoch 8 Loss: 0.2787 Accuracy: 91.81%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.35batch/s]


    Accuracy: 87.45% Avg_Loss: 0.4414


Epoch 9/10: 100%|██████████| 192/192 [00:08<00:00, 23.08batch/s, loss=0.294] 


Epoch 9 Loss: 0.1711 Accuracy: 95.18%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.32batch/s]


    Accuracy: 88.82% Avg_Loss: 0.4093


Epoch 10/10: 100%|██████████| 192/192 [00:08<00:00, 23.12batch/s, loss=0.069] 


Epoch 10 Loss: 0.1049 Accuracy: 97.20%


>>> Evaluating: 100%|██████████| 8/8 [00:01<00:00,  5.33batch/s]


    Accuracy: 88.92% Avg_Loss: 0.3969
--->


>>> Evaluating: 100%|██████████| 49/49 [00:06<00:00,  7.49batch/s]

    Accuracy: 77.70% Avg_Loss: 1.1579





(1.1578952986366895, 77.70369165718003)

In [4]:
# import umap
# import matplotlib.pyplot as plt
# from sklearn.preprocessing import StandardScaler

# def plot_umap_with_labels(features, labels, class_names, n_neighbors=15, min_dist=0.1, n_components=2):
#     features = StandardScaler().fit_transform(features)
#     reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components)
#     embedding = reducer.fit_transform(features)
    
#     plt.figure(figsize=(20, 16))  # 增大图像尺寸以适应更多标签
#     scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap='Spectral', s=10)
    
#     # 使用文本而不是颜色条来表示类别
#     for i, class_name in enumerate(class_names):
#         idx = labels == i
#         if np.any(idx):
#             centroid = np.mean(embedding[idx], axis=0)
#             plt.annotate(class_name, centroid, fontsize=8, alpha=0.7)
    
#     plt.title("UMAP projection of Flowers102 dataset")
#     plt.xlabel('UMAP Dimension 1')
#     plt.ylabel('UMAP Dimension 2')
#     plt.tight_layout()
#     plt.show()

# def extract_features(model, dataloader, device):
#     model.eval()
#     features = []
#     labels = []
#     with torch.no_grad():
#         for inputs, targets in tqdm(dataloader, desc="Extracting Features"):
#             inputs = inputs.to(device)
#             outputs = model(inputs)
#             features.append(outputs.cpu().numpy())
#             labels.append(targets.numpy())
#     return np.vstack(features), np.concatenate(labels)

# # 运行特征提取
# features, labels = extract_features(model_syn, testloader, device)

# # 加载花名（假设您已经下载了包含花名的文本文件）
# class_names = class_names = [
#     'pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',
#     'english marigold', 'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood',
#     'globe thistle', 'snapdragon', "colt's foot", 'king protea', 'spear thistle',
#     'yellow iris', 'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower',
#     'giant white arum lily', 'fire lily', 'pincushion flower', 'fritillary',
#     'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers',
#     'stemless gentian', 'artichoke', 'sweet william', 'carnation',
#     'garden phlox', 'love in the mist', 'mexican aster', 'alpine sea holly',
#     'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip',
#     'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia',
#     'bolero deep blue', 'wallflower', 'marigold', 'buttercup', 'oxeye daisy',
#     'common dandelion', 'petunia', 'wild pansy', 'primula', 'sunflower',
#     'pelargonium', 'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia',
#     'pink-yellow dahlia?', 'cautleya spicata', 'japanese anemone', 'black-eyed susan',
#     'silverbush', 'californian poppy', 'osteospermum', 'spring crocus',
#     'bearded iris', 'windflower', 'tree poppy', 'gazania', 'azalea', 'water lily',
#     'rose', 'thorn apple', 'morning glory', 'passion flower', 'lotus', 'toad lily',
#     'anthurium', 'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose',
#     'tree mallow', 'magnolia', 'cyclamen', 'watercress', 'canna lily', 'hippeastrum',
#     'bee balm', 'ball moss', 'foxglove', 'bougainvillea', 'camellia', 'mallow',
#     'mexican petunia', 'bromelia', 'blanket flower', 'trumpet creeper',
#     'blackberry lily'
# ]

# # 运行 UMAP 并显示
# plot_umap_with_labels(features, labels, class_names)