In [1]:
import os
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output 
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from models.allConv import AllConv  
from models.mlpContrastive import MLPContrastive
from trainer.train import Trainer
from losses.puLoss import PULoss
from dataTools.mnist import MNIST_Chainer, load_dataset, CIFAR10_Chainer

In [2]:
# Restrict this process to physical GPUs 4-7. This must run before CUDA is
# initialised (PyTorch initialises CUDA lazily, so importing torch is not enough
# to lock the device list in). Use the documented plain comma-separated format,
# without embedded spaces, so no parser has to be lenient about whitespace.
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"

# Load dataset (加载数据集)

In [3]:
SEED = 0
# Seed every global RNG the setup below may draw from. torch.manual_seed seeds the
# CPU and all CUDA generators. numpy is seeded too because the project helper
# load_dataset may draw its labeled/unlabeled split from numpy's global RNG
# (NOTE(review): confirm which RNG load_dataset uses).
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Extra DataLoader options used by the next cell. Pinned host memory only helps
# when batches are copied to a GPU, so these options are skipped on CPU-only hosts.
kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}

# PU split of CIFAR-10. Per the printed output, the training set has 10000 labeled
# positives and 40000 unlabeled samples. `prior` is passed to PULoss/Trainer
# (presumably the positive-class prior; verify in dataTools.mnist.load_dataset).
XYtrain, XYtest, prior = load_dataset("cifar10", 10000, 40000)
prior = torch.tensor(prior)

(50000, 3, 32, 32)
training:(50000, 3, 32, 32) consist of 10000 labeled positive samples and 40000 unlabeled samples
test:(10000, 3, 32, 32)


In [4]:
batch_size = 1024                      # per-GPU batch size
n_gpu = 2                              # number of GPUs the Trainer is told to use
total_batch_size = n_gpu * batch_size  # batch size actually drawn by each loader

# Fail fast if fewer GPUs are visible (after CUDA_VISIBLE_DEVICES) than n_gpu.
# This check is skipped on CPU-only machines.
if torch.cuda.is_available():
    n_visible_gpus = torch.cuda.device_count()
    assert n_visible_gpus >= n_gpu, (
        f"n_gpu={n_gpu} but only {n_visible_gpus} CUDA device(s) are visible"
    )

dataset = {'train': CIFAR10_Chainer(XYtrain),
           'valid': CIFAR10_Chainer(XYtest)}
# 'train' shuffles and drops the last partial batch. 'validtrain' walks the same
# training set in fixed order, keeping every sample (presumably for evaluating on
# the training data). 'valid' walks the test set in fixed order.
dataloader = {'train': DataLoader(dataset['train'], batch_size=total_batch_size, shuffle=True, drop_last=True, **kwargs),
              'validtrain': DataLoader(dataset['train'], batch_size=total_batch_size, shuffle=False, **kwargs),
              'valid': DataLoader(dataset['valid'], batch_size=total_batch_size, shuffle=False, **kwargs)}

lr = 0.01        # 0.0001 was the earlier value
n_epochs = 300

# Keyword arguments forwarded to the Trainer constructor.
kwargs2 = {
    'train_Dataloader': dataloader['train'],
    'valid_Dataloader': dataloader['valid'],
    'validtrain_Dataloader': dataloader['validtrain'],
    'epochs': n_epochs,
    'n_gpu': n_gpu,
    'notebook': True,
}


# nnPU: train AllConv with the non-negative PU loss

In [5]:
# Classifier plus its optimisation setup for the nnPU run.
model = AllConv().to(device)

# Adam's weight_decay is a coupled L2 penalty (not AdamW-style decoupled decay).
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=0.005,
)

# T_max equals the epoch count. This assumes the Trainer steps the scheduler once
# per epoch (NOTE(review): confirm in trainer/train.py).
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=kwargs2["epochs"])

# nnPU=True presumably enables the non-negative risk correction (see losses/puLoss.py).
nnpu_criterion = PULoss(prior=prior, nnPU=True)

trainer_nnPU = Trainer(
    'nnPU',
    model,
    device,
    nnpu_criterion,
    prior,
    optimizer,
    lr_scheduler=scheduler,
    **kwargs2
)

In [6]:
trainer_nnPU.run_trainer()
# Wipe the per-epoch tqdm progress bars first. clear_output() erases everything
# this cell has already emitted, so printing before it would hide the result.
clear_output()
# Counter kept by the loss object (presumably how often the nnPU non-negative
# correction was triggered; confirm in losses/puLoss.py).
print(trainer_nnPU.criterion.number_of_negative_loss)

Progress:   0%|          | 0/300 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Validation:   0%|          | 0/5 [00:00<?, ?it/s]

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

Training:   0%|          | 0/24 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Validation accuracy history recorded by the trainer during run_trainer().
# (matplotlib.pyplot is already imported as plt in the first cell.)
# NOTE(review): the x-axis assumes one valid_acc entry per validation pass.
fig, ax = plt.subplots()
ax.plot(trainer_nnPU.valid_acc)
ax.set(title="nnPU validation accuracy", xlabel="validation run", ylabel="accuracy")
plt.show()

In [None]:
# Evaluate a saved checkpoint. Override the directory with the NNPU_CHECKPOINT_DIR
# environment variable instead of editing a machine-specific absolute path.
CHECKPOINT_DIR = os.environ.get(
    "NNPU_CHECKPOINT_DIR",
    "/root/userfolder/projects/biomed-clip-puNCE/Reproduce/mynnPU/checkpoints",
)
checkpoint_path = os.path.join(CHECKPOINT_DIR, "checkpoint_20230713164346.pth")
# Fail with a clear message rather than a deep loader traceback when the file is absent.
assert os.path.isfile(checkpoint_path), f"checkpoint not found: {checkpoint_path}"
trainer_nnPU.run_validate(checkpoint_path)
