In [None]:
# only use one GPU
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import setpath
import numpy as np
import torch
from torchvision.datasets import VisionDataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim
import torch.nn as nn
from run.start import get_device
from torch.utils.tensorboard import SummaryWriter
from utils.eegutils import get_simple_log_dir

In [None]:
def generate_gaussian_noise_images(num_images, mean, std, ran_mean, ran_std, label,
                                   shape=(1, 224, 224), rng=None):
    """Generate ``num_images`` Gaussian-noise images that all carry ``label``.

    Each image is drawn from N(mean, std**2) with the given ``shape``. If
    ``ran_mean`` or ``ran_std`` is non-zero, an independent
    N(ran_mean, ran_std**2) array of the same shape is added on top. The result
    is then distributed as N(mean + ran_mean, std**2 + ran_std**2). Passing
    ``ran_mean = ran_std = 0`` disables the extra term.

    Args:
        num_images: number of images to generate (0 yields an empty list).
        mean, std: parameters of the base noise.
        ran_mean, ran_std: parameters of the optional additive noise term.
        label: label attached to every generated image.
        shape: shape of every image array; defaults to one 224x224 channel.
        rng: object with a numpy-style ``normal(loc, scale, size)`` method,
            e.g. ``numpy.random.default_rng(seed)``. Defaults to the global
            ``numpy.random`` state.

    Returns:
        list of ``(image, label)`` tuples, where ``image`` is a numpy array
        of ``shape`` (float64 for numpy RNGs).
    """
    rng = np.random if rng is None else rng
    # Skip the extra term only when BOTH parameters are zero. The previous
    # `and` condition silently ignored a zero-mean extra noise (ran_mean == 0)
    # and a pure mean shift (ran_std == 0).
    add_extra = ran_mean != 0 or ran_std != 0
    images = []
    for _ in range(num_images):
        image = rng.normal(mean, std, shape)
        if add_extra:
            image = image + rng.normal(ran_mean, ran_std, shape)
        images.append((image, label))
    return images

class GaussianNoiseDataset(VisionDataset):
    """Map-style dataset over an in-memory list of ``(image, label)`` pairs.

    Each stored image is a numpy array. On access it is converted to a float32
    tensor, passed through ``transform`` when one was given, and returned with
    its label unchanged.
    """

    def __init__(self, data, transform=None):
        # VisionDataset expects a root path; the samples are held in memory,
        # so an empty root is passed.
        super().__init__('', transform=transform)
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample, target = self.data[index]
        tensor = torch.from_numpy(sample).float()
        return (self.transform(tensor) if self.transform else tensor), target

In [None]:
num_classes = 40
# 50 training and 100 test samples per class (2,000 / 4,000 images in total).
num_train_samples = 50
num_test_samples = 100
mean = 0
std = 1
batch_size = 128
seed = 0

# Every class is drawn from the same N(mean, std**2) distribution with no
# per-class offset (ran_mean = ran_std = 0). The inputs therefore carry no
# information about the label: the network can only memorise the training set,
# and test accuracy is expected to stay at chance level (1 / num_classes).

# Seed numpy (used by generate_gaussian_noise_images) and torch (model init and
# DataLoader shuffling in later cells) so runs are repeatable. GPU kernels may
# still introduce some nondeterminism.
np.random.seed(seed)
torch.manual_seed(seed)


def _to_float32(samples):
    """Return ``samples`` with every image cast to float32.

    GaussianNoiseDataset converts images with ``.float()`` on access anyway, so
    the tensors it yields are identical. Casting up front halves host memory:
    about 1.2 GB instead of 2.4 GB for 6,000 images of 1x224x224.
    """
    return [(image.astype(np.float32), label) for image, label in samples]


train_data, test_data = [], []
for label in range(num_classes):
    train_data += _to_float32(generate_gaussian_noise_images(
        num_train_samples, mean, std, 0, 0, label))
    test_data += _to_float32(generate_gaussian_noise_images(
        num_test_samples, mean, std, 0, 0, label))

assert len(train_data) == num_classes * num_train_samples
assert len(test_data) == num_classes * num_test_samples

train_dataset = GaussianNoiseDataset(train_data, None)
test_dataset = GaussianNoiseDataset(test_data, None)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
from run.resnet import TesNet

log_dir = get_simple_log_dir()
summary = SummaryWriter(log_dir=log_dir)

# The first cell pins CUDA_VISIBLE_DEVICES to "0", so "cuda:0" is that GPU.
# Fall back to the CPU so the notebook still runs when no GPU is visible.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = TesNet(num_classes=num_classes, pretrained=False).to(device)

# Optionally freeze the embedding layer. The attribute name `fea_e` comes from
# earlier experiments -- NOTE(review): confirm it still exists on TesNet
# before enabling this.
freeze_embedding = False
if freeze_embedding:
    for param in model.fea_e.parameters():
        param.requires_grad = False

learning_rate = 0.001
momentum = 0.9
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)


In [None]:
def _logits(model, images):
    """Forward 1-channel ``images`` through ``model`` and return its last output.

    The single channel is tiled to three with ``repeat(1, 3, 1, 1)`` before the
    forward pass. The model returns three values, and only the last one is
    used as the class logits.
    NOTE(review): confirm the first two outputs are meant to be ignored.
    """
    _, _, logits = model(images.repeat(1, 3, 1, 1))
    return logits


def train_one_epoch(model, loader, criterion, optimizer, device, writer, epoch):
    """Run one optimisation pass over ``loader``.

    Logs every batch loss under ``'train_loss'`` at global step
    ``epoch * len(loader) + batch_idx``.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen in the epoch.
        The loss mean is weighted by batch size. Returns NaNs for an empty
        loader.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    for batch_idx, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = _logits(model, inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch_size
        writer.add_scalar('train_loss', loss.item(), epoch * len(loader) + batch_idx)
    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, 100.0 * correct / seen


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        ``(mean_loss, accuracy_percent)``. The loss mean is weighted by batch
        size, so a smaller final batch does not bias it. Returns NaNs for an
        empty loader.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = _logits(model, images)
        batch_size = labels.size(0)
        # Assumes criterion returns the per-batch mean (the CrossEntropyLoss
        # default), so re-weight by batch size before averaging.
        loss_sum += criterion(outputs, labels).item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size
    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100.0 * correct / total


num_epochs = 1_000_000  # effectively "train until interrupted"; must be an int for range()
# Optional early stopping on training accuracy (disabled by default).
early_stopping = False
patience = 10
best_train_acc = 0.0
patience_left = patience

try:
    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, summary, epoch)
        summary.add_scalar('train_acc', train_acc, epoch)
        summary.add_scalar('train_epoch_loss', train_loss, epoch)

        if early_stopping:
            if train_acc > best_train_acc:
                best_train_acc = train_acc
                patience_left = patience
            else:
                patience_left -= 1
                if patience_left == 0:
                    print('early stopping at epoch', epoch)
                    break

        test_loss, test_acc = evaluate(model, test_loader, criterion, device)
        summary.add_scalar('test_acc', test_acc, epoch)
        summary.add_scalar('test_loss', test_loss, epoch)
finally:
    # Write buffered TensorBoard events to disk even if the run is interrupted
    # (e.g. via KeyboardInterrupt in the notebook).
    summary.flush()

print('Finished Training')