In [3]:
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.datasets import DatasetFolder
from PIL import Image
from torch.utils.data import ConcatDataset,DataLoader,Subset,Dataset
import gc

# Data augmentation.
# Every image is resized to IMAGE_SIZE; Net's first Linear layer
# (256 * 8 * 8 inputs) depends on this 128x128 input size.
IMAGE_SIZE = (128, 128)

# Training images additionally get random flips, a random rotation
# (degrees=180) and random grayscale conversion before becoming tensors.
train_tfm = transforms.Compose(
    [transforms.Resize(IMAGE_SIZE)]
    + [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(180),
        transforms.RandomGrayscale(),
    ]
    + [transforms.ToTensor()]
)

# Evaluation images: deterministic resize only, no augmentation.
test_tfm = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()])

batch_size = 128

def load_pic(x):
    """Load the image file at path ``x`` as a 3-channel RGB PIL image.

    ``Image.open`` is lazy and keeps the file handle open until the pixels
    are read. Opening inside ``with`` and calling ``convert`` forces the
    load and closes the file promptly. ``convert("RGB")`` also turns any
    non-RGB file (e.g. grayscale or CMYK JPEGs) into 3 channels, which the
    network's first ``Conv2d(3, ...)`` layer expects.
    """
    with Image.open(x) as img:
        return img.convert("RGB")

DATA_DIR = "drive/MyDrive/food-11"

# All datasets reuse the named `load_pic` loader instead of five identical
# lambdas. Unlike a lambda, a module-level function can be pickled when
# DataLoader workers are started with the "spawn" method (it must then be
# importable, e.g. from a .py module).
# The labeled training images are used twice: once with random augmentation
# (train_tfm) and once un-augmented (test_tfm); the two are concatenated below.
train_set_1 = DatasetFolder(f"{DATA_DIR}/training/labeled", load_pic, extensions="jpg", transform=train_tfm)
train_set_2 = DatasetFolder(f"{DATA_DIR}/training/labeled", load_pic, extensions="jpg", transform=test_tfm)
valid_set = DatasetFolder(f"{DATA_DIR}/validation", load_pic, extensions="jpg", transform=test_tfm)
# NOTE(review): unlabeled images use the random train_tfm, so pseudo_label()
# predicts on augmented views; test_tfm might give more reliable pseudo-labels.
unlabeled_set = DatasetFolder(f"{DATA_DIR}/training/unlabeled", load_pic, extensions="jpg", transform=train_tfm)
test_set = DatasetFolder(f"{DATA_DIR}/testing", load_pic, extensions="jpg", transform=test_tfm)

train_set = ConcatDataset([train_set_1, train_set_2])
print(len(train_set))

# Construct data loaders. Only the training loader needs shuffling;
# validation metrics do not depend on sample order.
train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=8, pin_memory=True, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, num_workers=8, pin_memory=True, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

KeyboardInterrupt: ignored

In [4]:
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn as nn

class Mydataset(Dataset):
  """Minimal map-style dataset pairing ``x[i]`` with ``y[i]``.

  ``x`` and ``y`` only need to support indexing; ``len(x)`` defines the
  dataset length. They are stored as ``self.data`` and ``self.label``.
  """

  def __init__(self, x, y):
    self.data, self.label = x, y

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    sample = self.data[idx]
    target = self.label[idx]
    return sample, target
# Semi-supervised learning: attach pseudo-labels to unlabeled images.
def pseudo_label(model, unlabel_set, threshold=0.0, batch_size=128):
  """Build a labeled dataset from ``model``'s predictions on ``unlabel_set``.

  Each input is labeled with the argmax of the model's softmax output.

  Args:
    model: classifier; assumed to return (batch, num_classes) logits.
      Inference runs on the device of its first parameter. The model's
      train/eval mode is restored before returning.
    unlabel_set: dataset yielding ``(input, label)`` pairs; the label is ignored.
    threshold: minimum softmax probability of the predicted class for a
      sample to be kept. The default 0.0 keeps every sample; raise it to
      drop low-confidence predictions.
    batch_size: batch size used for inference.

  Returns:
    A ``Mydataset`` of CPU tensors (inputs, predicted labels). It has
    length 0 when ``unlabel_set`` is empty or no prediction reaches
    ``threshold``.
  """
  device = next(model.parameters()).device
  was_training = model.training
  softmax = nn.Softmax(dim=-1)
  kept_inputs, kept_labels = [], []
  model.eval()
  with torch.no_grad():
    data_loader = DataLoader(unlabel_set, batch_size=batch_size, shuffle=False)
    for inputs, _ in data_loader:
      probs = softmax(model(inputs.to(device)))
      confidence, labels = torch.max(probs, dim=1)
      # Keep results on the CPU so the pseudo-labeled set does not stay
      # resident in GPU memory for the rest of training.
      keep = (confidence >= threshold).cpu()
      kept_inputs.append(inputs[keep])
      kept_labels.append(labels.cpu()[keep])
  model.train(was_training)

  if not kept_inputs:
    # Empty unlabel_set: return an empty dataset rather than failing on an
    # unbound variable.
    return Mydataset(x=torch.empty(0), y=torch.empty(0, dtype=torch.long))
  # Concatenate once instead of nesting a new ConcatDataset per batch,
  # which made per-index lookup cost grow with the number of batches.
  return Mydataset(x=torch.cat(kept_inputs), y=torch.cat(kept_labels))


# Network architecture: three conv stages followed by a 3-layer MLP head.
class Net(nn.Module):
    """Small CNN classifier producing 11 logits per image.

    With 3x128x128 inputs, the three max-pools (2, 2, 4) shrink the spatial
    size 128 -> 64 -> 32 -> 8. The flattened feature vector therefore has
    256 * 8 * 8 entries.
    """

    @staticmethod
    def _conv_stage(c_in, c_out, pool):
        """Layers for one stage: 3x3 same-padding conv -> BatchNorm -> ReLU -> MaxPool."""
        return [
            nn.Conv2d(c_in, c_out, 3, 1, 1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
            nn.MaxPool2d(pool, pool, 0),
        ]

    def __init__(self):
        super().__init__()
        stage = Net._conv_stage
        # One flat Sequential keeps the layer indices (cn.0 ... cn.11) unchanged.
        self.cn = nn.Sequential(
            *stage(3, 64, 2),
            *stage(64, 128, 2),
            *stage(128, 256, 4),
        )
        self.fc = nn.Sequential(
            nn.Linear(256 * 8 * 8, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 11),
        )

    def forward(self, x):
        features = self.cn(x).flatten(1)
        return self.fc(features)

LEARNING_RATE = 0.0003
EPOCH = 80
# Pseudo-labeling is triggered once validation accuracy exceeds this value.
# Accuracy can never exceed 1.0, so the default keeps semi-supervised
# learning switched off; lower it (e.g. to 0.7) to enable it.
PSEUDO_LABEL_VAL_ACC = 1.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = Net().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


def train_one_epoch(loader):
    """Run one optimisation pass of the global ``model`` over ``loader``.

    Returns:
        (sum of per-batch mean losses, number of correctly classified samples).
    """
    model.train()
    loss_sum, correct = 0.0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a plain float. Summing the loss tensor itself would
        # chain every batch's autograd graph into one ever-growing graph.
        loss_sum += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
    return loss_sum, correct


def evaluate(loader):
    """Evaluate the global ``model`` on ``loader`` without gradients.

    Returns:
        (sum of per-batch mean losses, number of correctly classified samples).
    """
    model.eval()
    loss_sum, correct = 0.0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    return loss_sum, correct


train_acc_list, val_acc_list = [], []
train_loss_list, val_loss_list = [], []
semi = False
unlabeled_loader = None
for epoch in range(EPOCH):
    # Supervised training on the labeled set.
    train_loss, train_correct = train_one_epoch(train_loader)

    # Semi-supervised step on the pseudo-labeled set, once it exists.
    if semi:
        train_one_epoch(unlabeled_loader)
    gc.collect()
    torch.cuda.empty_cache()

    # Validation.
    val_loss, val_correct = evaluate(valid_loader)

    # Accuracy is per sample; loss is averaged over batches.
    train_acc = train_correct / len(train_set)
    val_acc = val_correct / len(valid_set)
    train_loss /= len(train_loader)
    val_loss /= len(valid_loader)
    train_acc_list.append(train_acc)
    val_acc_list.append(val_acc)
    train_loss_list.append(train_loss)
    val_loss_list.append(val_loss)
    print(f"{epoch} t {train_acc:.4f} (loss {train_loss:.4f})  v {val_acc:.4f} (loss {val_loss:.4f})")

    # Pseudo-label the unlabeled set once, when validation accuracy is good enough.
    if not semi and val_acc > PSEUDO_LABEL_VAL_ACC:
        pseudo_set = pseudo_label(model, unlabeled_set)
        # This loader is used for training, so shuffle it.
        unlabeled_loader = DataLoader(pseudo_set, batch_size=batch_size, num_workers=0, shuffle=True)
        semi = True

# Accuracy curves: distinct colours plus a legend so train and validation
# can be told apart; the x-axis follows the number of recorded epochs.
epochs = range(len(train_acc_list))
fig, ax = plt.subplots()
ax.plot(epochs, train_acc_list, "r-", label="train")
ax.plot(epochs, val_acc_list, "b-", label="validation")
ax.set(xlabel="epoch", ylabel="accuracy", title="Accuracy per epoch")
ax.legend()
plt.show()


KeyboardInterrupt: ignored