In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle

def load_cifar_batch(file):
    """Unpickle a single CIFAR batch file and return the stored object.

    Args:
        file: Path of the pickled file to read.

    Returns:
        The unpickled object. It is decoded with ``encoding='bytes'``; the
        rest of this notebook indexes the result with bytes keys such as
        ``b'data'`` and ``b'labels'``.

    Note:
        ``pickle.load`` can execute arbitrary code, so only point this at
        trusted dataset files.
    """
    with open(file, 'rb') as stream:
        return pickle.load(stream, encoding='bytes')

# Folder that holds the CIFAR-10 python batch files (Kaggle dataset mount).
cifar10_dir = '/kaggle/input/deep-learning-mini-project-spring-24-nyu/cifar-10-python/cifar-10-batches-py'

# Labelled CIFAR-10 test batch; this notebook uses it as its validation split.
# The path is derived from cifar10_dir so the dataset location lives in one place.
validation_batch = load_cifar_batch(os.path.join(cifar10_dir, 'test_batch'))

# Class-name table: label_names[i] is the (bytes) name of class index i.
meta_data_dict = load_cifar_batch(os.path.join(cifar10_dir, 'batches.meta'))
label_names = meta_data_dict[b'label_names']


def _preview_cifar_batch(batch_images, batch_labels, start=10, count=10):
    """Plot ``count`` consecutive images of one batch in a single row, titled by class name."""
    plt.figure(figsize=(10, 10))
    for column in range(count):
        plt.subplot(1, count, column + 1)  # one row, ``count`` columns
        plt.imshow(batch_images[start + column])
        # label_names is the notebook-level class-name table read from batches.meta.
        plt.title(label_names[batch_labels[start + column]].decode('utf-8'), fontsize=10)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


def load_cifar10_batches(cifar10_dir, batch_ids, show_preview=True):
    """Load and concatenate CIFAR-10 training batches.

    Args:
        cifar10_dir: Directory containing the ``data_batch_<id>`` pickle files.
        batch_ids: Iterable of batch ids to load, in order.
        show_preview: When true (the default, matching the previous
            behaviour) a figure with images 10..19 of every batch is shown
            while loading. Pass ``False`` to load silently; the preview relies
            on the notebook globals ``plt`` and ``label_names``.

    Returns:
        ``(images, labels)`` where ``images`` has shape ``(N, 32, 32, 3)``
        (channel-last) and ``labels`` is a 1-D array with one entry per image.

    Raises:
        ValueError: from ``np.concatenate`` when ``batch_ids`` is empty.
    """
    images = []
    labels = []

    for batch_id in batch_ids:
        batch_path = os.path.join(cifar10_dir, f'data_batch_{batch_id}')
        with open(batch_path, 'rb') as file:
            batch = pickle.load(file, encoding='bytes')

        # Reinterpret each flat row as (3, 32, 32) and move channels last, once per batch.
        batch_images = batch[b'data'].reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)
        batch_labels = batch[b'labels']

        if show_preview:
            _preview_cifar_batch(batch_images, batch_labels)

        images.append(batch_images)
        labels.append(batch_labels)

    images = np.concatenate(images)
    labels = np.concatenate(labels)

    return images, labels



# Load the five labelled training batches (data_batch_1 .. data_batch_5).
train_images, train_labels = load_cifar10_batches(cifar10_dir, range(1, 6))
# Reinterpret the test-batch rows as (3, 32, 32) images and move channels last.
validation_images = validation_batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)#.astype('float32')
# Presumably a plain Python list at this point; a later cell converts it with np.array.
validation_labels = validation_batch[b'labels']
print(train_images.shape)
print(len(train_labels))
#print(validation_images[0])
#print(train_images[0])
import numpy as np

# Append the validation images to the training images.
# NOTE(review): after this merge every validation image is also a training image,
# so metrics computed on validation_images are no longer a held-out estimate.
train_images = np.concatenate((train_images, validation_images), axis=0)

# Append the validation labels in the same order so images and labels stay aligned.
train_labels = np.concatenate((train_labels, validation_labels), axis=0)

In [None]:
# Competition test set: it ships without labels, so it can only be fed to a
# trained model to produce predictions.
test_batch = load_cifar_batch('/kaggle/input/deep-learning-mini-project-spring-24-nyu/cifar_test_nolabels.pkl')

# Reinterpret the stored pixels as (N, 3, 32, 32) and move channels last for plotting.
images = test_batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

# Preview the first ten test images in a single row.
preview_fig, preview_axes = plt.subplots(1, 10, figsize=(20, 4))
for position, axis in enumerate(preview_axes):
    axis.imshow(images[position])
    axis.axis('off')
plt.show()

In [None]:
from torch.utils.data import Dataset, DataLoader

class ImageDataset(Dataset):
    """In-memory image dataset with an optional albumentations-style transform.

    What ``dataset[idx]`` returns depends on the mode flags:

    * ``gan=True``  -> ``(image, 1)`` (constant dummy label); wins over ``test``
    * ``test=True`` -> ``image`` only
    * otherwise     -> ``(image, labels[idx])``
    """

    def __init__(self, images, labels = None, transform=None,test=False,gan = False):
        """
        Args:
            images: Indexable collection of images; ``images[idx]`` is one sample.
            labels: Indexable collection of labels aligned with ``images``.
                May be ``None`` only when ``test`` or ``gan`` is set.
            transform (callable, optional): Called as ``transform(image=image)``
                and expected to return a mapping whose ``'image'`` entry is the
                transformed sample.
            test (bool): Return images without labels.
            gan (bool): Return ``(image, 1)`` pairs, ignoring ``labels``.

        Raises:
            ValueError: If ``labels`` is ``None`` in labelled mode. Previously
                this only surfaced later as an opaque ``TypeError`` on indexing.
        """
        super().__init__()
        if labels is None and not (test or gan):
            raise ValueError("labels must be provided unless test=True or gan=True")
        self.images = images
        self.labels = labels
        self.transform = transform
        self.test = test
        self.gan = gan

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the (optionally transformed) sample at ``idx``; see the class docstring."""
        image = self.images[idx]

        if self.transform:
            image = self.transform(image=image)['image']

        if self.gan:
            return image, 1

        if self.test:
            return image

        return image, self.labels[idx]

In [None]:
import albumentations
from albumentations.pytorch.transforms import ToTensorV2

# Shared preprocessing settings, defined once so the training and evaluation
# pipelines cannot silently drift apart.
IMAGE_SIZE = 64                        # side length every image is resized to
NORM_MEAN = [0.4914, 0.4822, 0.4465]   # per-channel mean passed to Normalize
NORM_STD = [0.247, 0.243, 0.261]       # per-channel std passed to Normalize


def _normalize_and_to_tensor():
    """Return the two closing steps shared by both pipelines: normalise, then convert to a tensor."""
    return [
        albumentations.Normalize(
            NORM_MEAN, NORM_STD,
            max_pixel_value=255.0, always_apply=True
        ),
        ToTensorV2(p=1.0),
    ]


# Training pipeline: resize, random augmentations, then the shared closing steps.
transforms_train = albumentations.Compose(
    [
        albumentations.Resize(IMAGE_SIZE, IMAGE_SIZE),
        albumentations.HorizontalFlip(p=0.5),
        #albumentations.VerticalFlip(p=0.5),
        albumentations.Rotate(limit=45, p=0.5),
        albumentations.RandomBrightnessContrast(brightness_limit=(-0.5, 0.5), contrast_limit=(-0.5, 0.5), p=0.5),
        # rotate_limit=0: this step only shifts and rescales; rotation is handled by Rotate above.
        albumentations.ShiftScaleRotate(
            shift_limit=0.2, scale_limit=(-0.2,0.5), rotate_limit=0
        ),
        *_normalize_and_to_tensor(),
    ]
)

# Evaluation pipeline: same resize and normalisation, no random augmentation.
transforms_test = albumentations.Compose(
    [
        albumentations.Resize(IMAGE_SIZE, IMAGE_SIZE),
        *_normalize_and_to_tensor(),
    ]
)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets 
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts,ExponentialLR
# net = torchvision.models.resnet18()
# in_features = net.fc.in_features
# net.fc = nn.Linear(in_features, 10)

In [None]:
import torch
from torch import nn
from torchvision import transforms, datasets
import torch.optim as optim


class ResBlock(nn.Module):
    """Residual block that keeps both channel count and spatial size.

    The residual branch is pad -> 3x3 conv -> batch norm -> ReLU -> pad ->
    3x3 conv -> batch norm; its output is added to the block input.
    """

    def __init__(self, in_features):
        """
        Args:
            in_features: Number of channels of the input (and of the output).
        """
        super().__init__()
        branch_layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.BatchNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.BatchNorm2d(in_features),
        ]
        self.block = nn.Sequential(*branch_layers)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Image-to-image generator: stem, two stride-2 downsamplings, residual
    blocks, two stride-2 upsamplings and a tanh output layer.

    The output has the same channel count as the input and, because of the
    final ``Tanh``, values in ``[-1, 1]``.
    """

    def __init__(self, input_channels=3, num_residual_blocks=2):
        """
        Args:
            input_channels: Channels of the input image (and of the output).
            num_residual_blocks: How many ``ResBlock(96)`` modules sit in the bottleneck.
        """
        super().__init__()

        # Stem: 7x7 convolution, reflection-padded by 3.
        stem_layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, 16, 7),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        ]
        self.initial = nn.Sequential(*stem_layers)

        # Encoder: two stride-2 convolutions, 16 -> 32 -> 96 channels.
        encoder_layers = [
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 96, 3, stride=2, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
        ]
        self.downsample = nn.Sequential(*encoder_layers)

        # Bottleneck: a stack of residual blocks at 96 channels.
        bottleneck = []
        for _ in range(num_residual_blocks):
            bottleneck.append(ResBlock(96))
        self.resblocks = nn.Sequential(*bottleneck)

        # Decoder: two stride-2 transposed convolutions, 96 -> 32 -> 16 channels.
        decoder_layers = [
            nn.ConvTranspose2d(96, 32, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        ]
        self.upsample = nn.Sequential(*decoder_layers)

        # Head: 7x7 convolution back to the input channel count, squashed by tanh.
        head_layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(16, input_channels, 7),
            nn.Tanh(),
        ]
        self.output = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through stem, encoder, bottleneck, decoder and head in that order."""
        stages = (self.initial, self.downsample, self.resblocks, self.upsample, self.output)
        for stage in stages:
            x = stage(x)
        return x


class Discriminator(nn.Module):
    """Small convolutional discriminator for 3-channel images.

    Two stride-2 4x4 convolutions are followed by an unpadded 4x4 convolution
    to one channel and a sigmoid, so the output is a map of per-patch scores
    in ``[0, 1]`` rather than a single scalar per image.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 4, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(4, 8, 4, stride=2, padding=1),
            nn.BatchNorm2d(8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(8, 1, 4, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the score map for a batch of images."""
        scores = self.main(input)
        return scores

    # Initialize networks
# One generator per translation direction and one discriminator per domain.
G_A2B, G_B2A = Generator(), Generator()
D_A, D_B = Discriminator(), Discriminator()

# Adversarial objective is a squared error on the discriminator scores;
# cycle-consistency and identity terms use an L1 distance.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()

# A single Adam optimizer updates both generators; each discriminator has its own.
optimizer_G = optim.Adam(
    [*G_A2B.parameters(), *G_B2A.parameters()],
    lr=0.0002,
    betas=(0.5, 0.999),
)
optimizer_D_A = optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))


In [None]:
# Use the GPU when one is available and move all four GAN networks onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G_A2B, G_B2A, D_A, D_B = (
    network.to(device) for network in (G_A2B, G_B2A, D_A, D_B)
)

In [None]:
# Domain A: the labelled CIFAR-10 test-batch images, paired with their own labels
# (previously paired with train_labels, which belong to a different, longer array).
dataset_A = ImageDataset(validation_images, validation_labels, transform=transforms_test)

# Domain B: the unlabeled competition images; gan=True makes the dataset emit a dummy label.
dataset_B = ImageDataset(images, gan=True, transform=transforms_test)

# Datasets and loaders now have distinct names, so each name refers to one kind of object.
loader_A = DataLoader(dataset_A, batch_size=64, shuffle=True, drop_last=False)
loader_B = DataLoader(dataset_B, batch_size=64, shuffle=True, drop_last=False)

In [None]:
import torch

# CycleGAN-style training between domain A (loader_A) and domain B (loader_B).
# Relies on the networks, criteria, optimizers and `device` created in earlier cells.
# Number of epochs
n_epochs = 40

# Weights of the cycle-consistency and identity terms in the generator loss
lambda_cycle = 5.0
lambda_identity = 1.0
from tqdm.auto import tqdm
# NOTE: `epoch` is a notebook-level name; the train_* functions defined later read a
# global called `epoch` as their epoch count, so it is reassigned before they are called.
for epoch in range(n_epochs):
    # zip pairs the two loaders batch by batch and stops at the shorter one;
    # zip has no length, so tqdm shows a running count without a total.
    for batch_A, batch_B in tqdm(zip(loader_A, loader_B)):
        # Element 0 of each batch is the image tensor; the labels are not used here.
        real_A = batch_A[0].to(device)
        real_B = batch_B[0].to(device)

        ###### Generators A2B and B2A ######
        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image of its target domain should return it unchanged
        same_B = G_A2B(real_B)
        loss_identity_B = criterion_cycle(same_B, real_B) * lambda_identity
        same_A = G_B2A(real_A)
        loss_identity_A = criterion_cycle(same_A, real_A) * lambda_identity

        # GAN loss: translated images should be scored as real (target of ones)
        fake_B = G_A2B(real_A)
        pred_fake = D_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

        fake_A = G_B2A(real_B)
        pred_fake = D_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

        # Cycle loss: translating there and back should reproduce the input
        recovered_A = G_B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * lambda_cycle

        recovered_B = G_A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * lambda_cycle

        # Total generator loss. backward() also leaves gradients on the discriminators'
        # parameters; those are cleared by the zero_grad() calls before each discriminator step.
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        optimizer_G.step()

        ###### Discriminator A ######
        optimizer_D_A.zero_grad()

        # Real loss: real A images should be scored as ones
        pred_real = D_A(real_A)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        # Fake loss: generated A images should be scored as zeros;
        # detach() keeps this update from back-propagating into the generator
        pred_fake = D_A(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # Total loss (average of the real and fake terms)
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        ###### Discriminator B ######
        optimizer_D_B.zero_grad()

        # Real loss: real B images should be scored as ones
        pred_real = D_B(real_B)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        # Fake loss: generated B images should be scored as zeros (generator detached)
        pred_fake = D_B(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # Total loss (average of the real and fake terms)
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()

    # The values printed below are those of the LAST batch of the epoch, not epoch averages.
    # `epoch` is zero-based, so the final line reads "Epoch 39/40"; "D_X"/"D_Y" are D_A/D_B.
    print(f"Epoch {epoch}/{n_epochs} finished")
    print(f"Loss G: {loss_G.item():.4f}, "
          f"Loss D_X: {loss_D_A.item():.4f}, Loss D_Y: {loss_D_B.item():.4f}")


In [None]:
import matplotlib.pyplot as plt
import torch

def visualize_cycleGAN_results(model_G_A2B, model_G_B2A, test_loader_A, test_loader_B, max_batches=10):
    """Show real images next to their generator translations.

    For each of the first ``max_batches`` paired batches one 2x2 figure is
    drawn from the first image of each batch:

    * top row:    real A and ``model_G_A2B(real A)``, i.e. a generated B image
    * bottom row: real B and ``model_G_B2A(real B)``, i.e. a generated A image

    Args:
        model_G_A2B: Generator translating domain A to domain B.
        model_G_B2A: Generator translating domain B to domain A.
        test_loader_A: Iterable yielding ``(images, labels)`` batches of domain A.
        test_loader_B: Iterable yielding ``(images, labels)`` batches of domain B.
        max_batches: Number of figures to draw (default 10, as before).

    Side effects:
        Both generators are switched to eval mode and left in it. Tensors are
        moved to the notebook-level ``device``.
    """

    def _to_displayable(image_tensor):
        # CHW -> HWC and map [-1, 1] to [0, 1]; values outside that range are
        # clipped explicitly so imshow receives valid floats.
        # NOTE(review): inputs normalised with dataset mean/std would need those
        # statistics to be inverted exactly -- confirm which scaling is intended.
        return (image_tensor.permute(1, 2, 0).numpy() * 0.5 + 0.5).clip(0.0, 1.0)

    # Evaluation mode for inference.
    model_G_A2B.eval()
    model_G_B2A.eval()

    with torch.no_grad():  # no gradients needed for visualisation
        paired_batches = zip(test_loader_A, test_loader_B)
        # range() comes first so zip stops without fetching an extra batch pair.
        for _batch_index, (batchA, batchB) in zip(range(max_batches), paired_batches):
            real_A, _labels_A = batchA
            real_B, _labels_B = batchB
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Names follow the domain of the OUTPUT: A -> B yields a fake B.
            fake_B = model_G_A2B(real_A)
            fake_A = model_G_B2A(real_B)

            panels = (
                ("Original Image A", real_A),
                ("Fake B (from A)", fake_B),
                ("Original Image B", real_B),
                ("Fake A (from B)", fake_A),
            )

            plt.figure(figsize=(6, 3))
            for position, (title, batch) in enumerate(panels, start=1):
                plt.subplot(2, 2, position)
                plt.title(title)
                # Only the first image of each batch is shown.
                plt.imshow(_to_displayable(batch[0].cpu()))
                plt.axis('off')
            plt.show()

# Show sample translations using the two generators and the GAN data loaders built in earlier cells.
visualize_cycleGAN_results(G_A2B, G_B2A, loader_A,loader_B)


In [None]:
# Switch both generators to evaluation mode. G_A2B is called later, inside the
# classifier training functions, to translate some of the training batches.
G_A2B.eval()
G_B2A.eval()

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a residual shortcut.

    The shortcut is the identity unless the stride or the channel count
    changes, in which case it becomes a 1x1 convolution followed by batch norm.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """
        Args:
            in_planes: Input channels.
            planes: Output channels (times ``expansion``, which is 1 here).
            stride: Stride of the first convolution and of the projection shortcut.
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)


class Root(nn.Module):
    """Aggregation node: concatenate feature maps on the channel axis, then conv -> batch norm -> ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=1):
        """
        Args:
            in_channels: Total channels after concatenating all inputs.
            out_channels: Channels produced by the node.
            kernel_size: Convolution kernel size; padding is ``(kernel_size - 1) // 2``.
        """
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=1, padding=same_padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, xs):
        """``xs`` is a sequence of tensors that are concatenated along dimension 1."""
        stacked = torch.cat(xs, 1)
        return F.relu(self.bn(self.conv(stacked)))


class Tree(nn.Module):
    """Recursive aggregation tree built from ``block`` modules and a ``Root``.

    A level-1 tree applies two blocks in sequence and aggregates both outputs.
    A deeper tree additionally holds sub-trees ``level_<i>`` for every ``i``
    from ``level - 1`` down to 1 and a ``prev_root`` block whose output is
    aggregated too.
    """

    def __init__(self, block, in_channels, out_channels, level=1, stride=1):
        """
        Args:
            block: Module class constructed as ``block(in_ch, out_ch, stride=...)``.
            in_channels: Channels of the tree input.
            out_channels: Channels of every node output and of the tree output.
            level: Depth of the tree.
            stride: Stride given to the modules that consume the tree input.
        """
        super().__init__()
        self.level = level

        if level == 1:
            self.root = Root(2 * out_channels, out_channels)
            self.left_node = block(in_channels, out_channels, stride=stride)
            self.right_node = block(out_channels, out_channels, stride=1)
            return

        self.root = Root((level + 2) * out_channels, out_channels)
        # Sub-trees are registered from the deepest (level - 1) down to level 1.
        for sub_level in range(level - 1, 0, -1):
            subtree = Tree(block, in_channels, out_channels,
                           level=sub_level, stride=stride)
            setattr(self, 'level_%d' % sub_level, subtree)
        self.prev_root = block(in_channels, out_channels, stride=stride)
        self.left_node = block(out_channels, out_channels, stride=1)
        self.right_node = block(out_channels, out_channels, stride=1)

    def forward(self, x):
        """Collect the outputs of every stage and aggregate them with ``self.root``."""
        collected = []
        if self.level > 1:
            collected.append(self.prev_root(x))

        for sub_level in range(self.level - 1, 0, -1):
            x = getattr(self, 'level_%d' % sub_level)(x)
            collected.append(x)

        for node in (self.left_node, self.right_node):
            x = node(x)
            collected.append(x)

        return self.root(collected)


class DLA(nn.Module):
    """Compact DLA-style classifier: three conv stems, three aggregation trees, one linear head.

    The head expects 1024 features. With the three stride-2 trees and the
    4x4 average pooling this matches 256 channels on a 2x2 map, i.e. a
    64x64 input image.
    """

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """3x3 stride-1 convolution (no bias) followed by batch norm and ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def __init__(self, block=BasicBlock, num_classes=10):
        """
        Args:
            block: Building-block class handed to every ``Tree``.
            num_classes: Size of the output layer.
        """
        super().__init__()
        # Full-resolution stem: 3 -> 16 -> 16 -> 32 channels.
        self.base = self._conv_bn_relu(3, 16)
        self.layer1 = self._conv_bn_relu(16, 16)
        self.layer2 = self._conv_bn_relu(16, 32)

        # Each tree is built with stride 2.
        self.layer3 = Tree(block, 32, 64, level=1, stride=2)
        self.layer4 = Tree(block, 64, 128, level=2, stride=2)
        self.layer5 = Tree(block, 128, 256, level=1, stride=2)

        self.linear = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        out = x
        for stage in (self.base, self.layer1, self.layer2,
                      self.layer3, self.layer4, self.layer5):
            out = stage(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)





# Student classifier: the compact DLA defined above, with its default 10 output classes.
net = DLA()

In [None]:
import timm
# Teacher classifier: timm's 'dla102' architecture with a 10-class head; pretrained=False means no pretrained weights are loaded.
parent = timm.create_model('dla102', pretrained=False,num_classes=10)

In [None]:
# !pip install torch-summary

In [None]:
# from torchsummary import summary
# summary(model, (3, 64, 64))

In [None]:
# Use the GPU when one is available and move both classifiers (student `net`, teacher `parent`) onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net, parent = net.to(device), parent.to(device)

In [None]:
from tqdm.auto import tqdm

def train_model(train_loader, num_epochs=None):
    """Train the notebook-level student network ``net`` with cross-entropy.

    With probability 0.3 a batch is first translated by the generator
    ``G_A2B`` (used purely as an augmentation, so no gradients are tracked
    through it).

    Args:
        train_loader: DataLoader yielding ``(images, labels)`` batches.
        num_epochs: Number of epochs. When omitted, the notebook-level global
            ``epoch`` is used, which keeps existing ``train_model(loader)``
            calls working.

    Returns:
        List with the mean training loss of every epoch.
    """
    if num_epochs is None:
        num_epochs = epoch  # notebook-level global, kept for backward compatibility

    losses = []
    optimizer = optim.Adam(net.parameters(), lr=0.0002,weight_decay=1e-5)
    loss = nn.CrossEntropyLoss(reduction='mean')
    scheduler = ExponentialLR(optimizer, gamma=0.975,verbose=True)
    for i in range(num_epochs):
        correct = 0
        seen = 0
        loss_sum = 0.0
        net.train()
        for x, y in tqdm(train_loader):
            x = x.to(device)
            y = y.to(device)
            # GAN augmentation on roughly 30% of the batches; no_grad avoids building
            # an autograd graph through the generator that would never be used.
            if torch.rand(1).item() < 0.3:
                with torch.no_grad():
                    x = G_A2B(x)
            x = torch.as_tensor(x, dtype=torch.float)

            y_hat = net(x)
            loss_temp = loss(y_hat, y)
            optimizer.zero_grad()
            loss_temp.backward()
            optimizer.step()

            # Accumulate plain Python numbers so no graph is kept alive across batches.
            loss_sum += loss_temp.item()
            correct += (y_hat.argmax(dim=1) == y).sum().item()
            seen += y.size(0)  # exact sample count, also correct for a partial last batch
        scheduler.step()
        losses.append(loss_sum / max(len(train_loader), 1))
        # "训练集准确度" = training-set accuracy
        print( "epoch: ", i, "loss=", loss_sum, "训练集准确度=",correct / max(seen, 1),end="")
    return losses



In [None]:
from tqdm.auto import tqdm

def train_parent(train_loader, num_epochs=None):
    """Train the notebook-level teacher network ``parent`` with cross-entropy.

    With probability 0.3 a batch is first translated by the generator
    ``G_A2B`` (used purely as an augmentation, so no gradients are tracked
    through it).

    Args:
        train_loader: DataLoader yielding ``(images, labels)`` batches.
        num_epochs: Number of epochs. When omitted, the notebook-level global
            ``epoch`` is used, which keeps existing ``train_parent(loader)``
            calls working.

    Returns:
        List with the mean training loss of every epoch.
    """
    if num_epochs is None:
        num_epochs = epoch  # notebook-level global, kept for backward compatibility

    losses = []
    optimizer = optim.Adam(parent.parameters(), lr=0.00007,weight_decay=1e-5)
    loss = nn.CrossEntropyLoss(reduction='mean')
    scheduler = ExponentialLR(optimizer, gamma=0.975,verbose=True)
    for i in range(num_epochs):
        correct = 0
        seen = 0
        loss_sum = 0.0
        parent.train()
        for x, y in tqdm(train_loader):
            x = x.to(device)
            y = y.to(device)
            # GAN augmentation on roughly 30% of the batches; no_grad avoids building
            # an autograd graph through the generator that would never be used.
            if torch.rand(1).item() < 0.3:
                with torch.no_grad():
                    x = G_A2B(x)
            x = torch.as_tensor(x, dtype=torch.float)

            y_hat = parent(x)
            loss_temp = loss(y_hat, y)
            optimizer.zero_grad()
            loss_temp.backward()
            optimizer.step()

            # Accumulate plain Python numbers so no graph is kept alive across batches.
            loss_sum += loss_temp.item()
            correct += (y_hat.argmax(dim=1) == y).sum().item()
            seen += y.size(0)  # exact sample count, also correct for a partial last batch
        scheduler.step()
        losses.append(loss_sum / max(len(train_loader), 1))
        # "训练集准确度" = training-set accuracy
        print( "epoch: ", i, "loss=", loss_sum, "训练集准确度=",correct / max(seen, 1),end="")
    return losses


In [None]:
# class DistillationLoss(nn.Module):
#     def __init__(self, temperature=1):
#         super(DistillationLoss, self).__init__()
#         self.temperature = temperature

#     def forward(self, outputs_student, outputs_teacher):
#         soft_targets = nn.functional.softmax(outputs_teacher / self.temperature, dim=1)
#         log_probs = nn.functional.log_softmax(outputs_student / self.temperature, dim=1)
#         return nn.KLDivLoss(reduction='batchmean')(log_probs, soft_targets)


import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """KL-divergence distillation loss with an optional hard-label term.

    Student and teacher logits are both divided by ``temperature`` before the
    softmax. The KL term is returned as computed by ``nn.KLDivLoss`` with
    ``reduction='batchmean'``; no extra temperature-dependent scaling is applied.
    """

    def __init__(self, temperature=1.0, alpha=0.5):
        """
        :param temperature: Divisor applied to the logits to soften both distributions.
        :param alpha: Weight of the cross-entropy term when labels are given;
            the distillation term is weighted by ``1 - alpha``.
        """
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, outputs_student, outputs_teacher, labels=None):
        """
        Return the distillation loss between student and teacher logits.

        Without ``labels`` the result is the KL term alone. With ``labels`` it
        is ``alpha * cross_entropy(student, labels) + (1 - alpha) * KL``.
        """
        tau = self.temperature
        teacher_probs = F.softmax(outputs_teacher / tau, dim=1)
        student_log_probs = F.log_softmax(outputs_student / tau, dim=1)
        soft_loss = self.kl_div(student_log_probs, teacher_probs)

        if labels is None:
            return soft_loss

        hard_loss = F.cross_entropy(outputs_student, labels)
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss


In [None]:
from tqdm.auto import tqdm

def train_model_with_teacher(train_loader, num_epochs=None):
    """Train the student ``net`` against labels plus a frozen teacher ``parent``.

    The per-batch loss is cross-entropy on the labels plus the distillation
    loss (temperature 3) between student and teacher logits. Only the
    student's parameters are optimised. With probability 0.3 a batch is first
    translated by the generator ``G_A2B``.

    Args:
        train_loader: DataLoader yielding ``(images, labels)`` batches.
        num_epochs: Number of epochs. When omitted, the notebook-level global
            ``epoch`` is used, which keeps existing calls working.

    Returns:
        List with the mean training loss of every epoch.

    Side effects:
        ``parent`` is switched to eval mode and left in it.
    """
    if num_epochs is None:
        num_epochs = epoch  # notebook-level global, kept for backward compatibility

    losses = []
    optimizer = optim.Adam(net.parameters(), lr=0.0006,weight_decay=1e-5)
    loss = nn.CrossEntropyLoss(reduction='mean')
    distillation_loss = DistillationLoss(temperature=3)
    scheduler = ExponentialLR(optimizer, gamma=0.975,verbose=True)
    for i in range(num_epochs):
        correct = 0
        seen = 0
        loss_sum = 0.0
        net.train()
        # The teacher is not optimised here: eval mode keeps its batch-norm statistics fixed.
        parent.eval()
        for x, y in tqdm(train_loader):
            x = x.to(device)
            y = y.to(device)
            # GAN augmentation on roughly 30% of the batches, without tracking gradients.
            if torch.rand(1).item() < 0.3:
                with torch.no_grad():
                    x = G_A2B(x)
            x = torch.as_tensor(x, dtype=torch.float)

            # Teacher logits are targets only, so no graph is built for them.
            with torch.no_grad():
                outputs_teacher = parent(x)
            # A single student forward pass serves both the loss and the accuracy.
            outputs_student = net(x)

            loss_temp = loss(outputs_student, y) + distillation_loss(outputs_student, outputs_teacher)

            optimizer.zero_grad()
            loss_temp.backward()
            optimizer.step()

            # Accumulate plain Python numbers so no graph is kept alive across batches.
            loss_sum += loss_temp.item()
            correct += (outputs_student.argmax(dim=1) == y).sum().item()
            seen += y.size(0)  # exact sample count, also correct for a partial last batch
        scheduler.step()
        losses.append(loss_sum / max(len(train_loader), 1))
        # "训练集准确度" = training-set accuracy
        print( "epoch: ", i, "loss=", loss_sum, "训练集准确度=",correct / max(seen, 1),end="")
    return losses



In [None]:
from tqdm.auto import tqdm

def teach_student(train_loader, num_epochs=None):
    """Jointly train the teacher ``parent`` and the student ``net``.

    Per batch the teacher is updated on cross-entropy alone, while the student
    is updated on cross-entropy plus the distillation loss (temperature 3)
    towards the teacher's logits. The teacher logits are detached in the
    distillation term, so the student's loss never changes the teacher.

    Args:
        train_loader: DataLoader yielding ``(images, labels)`` batches.
        num_epochs: Number of epochs. When omitted, the notebook-level global
            ``epoch`` is used, which keeps existing calls working.

    Returns:
        List with the mean TEACHER loss of every epoch (the quantity that is
        also printed as ``loss=``).
    """
    if num_epochs is None:
        num_epochs = epoch  # notebook-level global, kept for backward compatibility

    losses = []
    teacher_optimizer = optim.Adam(parent.parameters(), lr=0.00003,weight_decay=1e-5)
    student_optimizer = optim.Adam(net.parameters(), lr=0.001,weight_decay=1e-5)
    loss = nn.CrossEntropyLoss(reduction='mean')
    distillation_loss = DistillationLoss(temperature=3)
    # One scheduler per optimizer; both are stepped at the end of every epoch.
    teacher_scheduler = ExponentialLR(teacher_optimizer, gamma=0.99,verbose=True)
    student_scheduler = ExponentialLR(student_optimizer, gamma=0.98,verbose=True)
    for i in range(num_epochs):
        correct_student = 0
        correct_teacher = 0
        seen = 0
        loss_sum = 0.0
        net.train()
        parent.train()
        for x, y in tqdm(train_loader):
            teacher_optimizer.zero_grad()
            student_optimizer.zero_grad()
            x = x.to(device)
            x = torch.as_tensor(x, dtype=torch.float)
            y = y.to(device)

            outputs_teacher = parent(x)
            outputs_student = net(x)

            loss_teacher = loss(outputs_teacher, y)
            # detach(): the teacher is a fixed target for the student, so the two
            # backward passes use independent graphs and retain_graph is not needed.
            loss_student = loss(outputs_student, y) + distillation_loss(
                outputs_student, outputs_teacher.detach()
            )

            loss_teacher.backward()
            loss_student.backward()
            teacher_optimizer.step()
            student_optimizer.step()

            # Accumulate plain Python numbers so no graph is kept alive across batches.
            loss_sum += loss_teacher.item()
            correct_student += (outputs_student.argmax(dim=1) == y).sum().item()
            correct_teacher += (outputs_teacher.argmax(dim=1) == y).sum().item()
            seen += y.size(0)  # exact sample count, also correct for a partial last batch
        teacher_scheduler.step()
        student_scheduler.step()
        losses.append(loss_sum / max(len(train_loader), 1))
        # "训练集准确度" = training-set accuracy (student); "teacher准确度" = teacher accuracy
        print( "epoch: ", i, "loss=", loss_sum, "训练集准确度=",correct_student / max(seen, 1),end="")
        print( "epoch: ", i, "loss=", loss_sum, "teacher准确度=",correct_teacher / max(seen, 1),end="")
    return losses


In [None]:
#from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts,ExponentialLR

# Turn the validation labels into an array so they expose .shape like the other splits.
validation_labels = np.array(validation_labels)

# Print the shape of every array that feeds the loaders.
for split_array in (train_images, train_labels, validation_images, validation_labels):
    print(split_array.shape)

# Training data is augmented; validation data is only resized and normalised.
# NOTE(review): an earlier cell appends the validation images to train_images,
# so the validation split overlaps the training data.
train_dataset = ImageDataset(
    train_images,
    train_labels,
    transform=transforms_train,
)
validation_dataset = ImageDataset(
    validation_images,
    validation_labels,
    transform=transforms_test,
)

train_loader = DataLoader(
    train_dataset, batch_size=64, shuffle=True, drop_last=False
)
validation_loader = DataLoader(
    validation_dataset, batch_size=64, shuffle=False, drop_last=False
)

In [None]:
# `epoch` is the notebook-level epoch count that train_parent() reads: train the teacher for 40 epochs.
epoch = 40
train_parent(train_loader)

In [None]:
# 80 epochs of joint training: teach_student() updates both the teacher (`parent`) and the student (`net`).
epoch = 80
teach_student(train_loader)
# Disabled alternatives kept from earlier experiments:
#net = parent
#train_model(train_loader,validation_loader)

In [None]:
# 70 further epochs in which only the student is optimised, with the teacher's logits as an extra distillation target.
epoch = 70
train_model_with_teacher(train_loader)

In [None]:
import pandas as pd
class ImageDatasetWithoutLabels(Dataset):
    """Dataset that serves images only (no labels), with an optional transform."""

    def __init__(self, images, transform=None):
        """
        Args:
            images: Indexable collection of images.
            transform: Optional callable invoked as ``transform(image=image)``;
                the ``'image'`` entry of its result is returned.
        """
        self.images, self.transform = images, transform

    def __len__(self):
        """Number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the image at ``idx``, transformed when a transform is set."""
        sample = self.images[idx]
        if not self.transform:
            return sample
        return self.transform(image=sample)['image']
    
# Unlabeled competition images, preprocessed like the validation data and kept
# in their original order (shuffle=False) so row i of the output is image i.
submit_dataset = ImageDatasetWithoutLabels(images, transform=transforms_test)
submit_loader = DataLoader(
    submit_dataset, batch_size=64, shuffle=False, drop_last=False
)

# Inference setup: evaluation mode, model on the selected device.
net.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

predictions = []

# No gradients are needed for prediction.
with torch.no_grad():
    for inputs in submit_loader:
        logits = net(inputs.to(device))
        # Predicted class = index of the largest logit in each row.
        predicted = torch.max(logits, 1).indices
        predictions.extend(predicted.cpu().numpy())

# The submission ID is simply the position of the image in the test set.
submission_df = pd.DataFrame(
    {
        'ID': np.arange(len(predictions)),
        'Labels': predictions,
    }
)

# Write the submission file
submission_df.to_csv('/kaggle/working/submission.csv', index=False)

In [None]:
#submission_df