In [1]:
import torch
import pandas as pd
from PIL import Image 
from torchvision import transforms,utils
from collections import OrderedDict
import numpy as np 
import torch.nn.functional as  F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image
import os
from torch.utils.data import SubsetRandomSampler
import torchvision.models as models

In [2]:
def default_loader(path):
    """Load the image at ``path`` and return it converted to RGB.

    The image is opened through a context manager so the source file is
    released as soon as the pixels have been converted, instead of staying
    open until the garbage collector reclaims the image object.

    Args:
        path: Location of the image; passed unchanged to ``Image.open``.

    Returns:
        The image produced by ``convert('RGB')``.
    """
    with Image.open(path) as img:
        # convert() reads the pixel data and returns a separate image object,
        # so the result remains usable after the source image is closed.
        return img.convert('RGB')
class MyDataset(Dataset):
    """Image dataset described by a whitespace-separated listing file.

    Every non-blank line of the listing starts with an image path.  In
    training mode the path is followed by an integer label::

        some/dir/img_001.jpg 0
        some/dir/img_002.jpg 1

    Args:
        txt: Path of the listing file.
        type: ``"train"`` reads ``(path, label)`` pairs and makes the dataset
            yield ``(image, label)``; any other value reads only the path and
            makes the dataset yield the bare image.
        transform: Optional callable applied to each loaded image.
        target_transform: Optional callable applied to each label
            (only meaningful in training mode).
        loader: Callable that turns a path into an image.
    """

    def __init__(self, txt, type, transform=None, target_transform=None, loader=default_loader):
        self.type = type
        imgs = []
        # The context manager closes the listing even when a line fails to parse.
        with open(txt, 'r') as fh:
            for line in fh:
                words = line.split()  # split into file name and label
                if not words:
                    # Blank / whitespace-only lines (e.g. a trailing newline)
                    # describe no sample and used to raise IndexError.
                    continue
                if self.type == "train":
                    imgs.append((words[0], int(words[1])))
                else:
                    imgs.append(words[0])
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label)`` in training mode, otherwise ``image``."""
        if self.type == "train":
            fn, label = self.imgs[index]
        else:
            fn, label = self.imgs[index], None
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        if self.type != "train":
            return img
        # target_transform was accepted by __init__ but never applied before.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Number of samples read from the listing file."""
        return len(self.imgs)

In [3]:
# Per-channel mean / std passed to Normalize by both pipelines.
NORMALIZE_MEAN = (0.5961016, 0.4565982, 0.39084524)
NORMALIZE_STD = (0.21863548, 0.19483651, 0.18572323)

# Training pipeline: resize, random 224x224 crop and small random rotation,
# then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.Resize(256),
    # NOTE(review): ColorJitter() is called with its default arguments, which
    # per the torchvision docs apply no jitter. Pass brightness/contrast/
    # saturation/hue explicitly if colour augmentation is wanted.
    transforms.ColorJitter(),
    transforms.RandomCrop(224),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

# Inference pipeline: deterministic, no augmentation.
transform_test = transforms.Compose([
    transforms.Resize(224),
    # Resize(int) only fixes the shorter side, so non-square images would keep
    # differing sizes and could not be stacked into one batch by the
    # DataLoader.  The centre crop pins every tensor to 224x224 and leaves
    # square inputs unchanged.
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

In [4]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Bare expression so the notebook displays the selected device.
device

device(type='cuda', index=0)

In [5]:
# Define the model: torchvision's VGG-16, requested with pretrained weights.
# The trailing bare expression makes the notebook display the architecture.
model = models.vgg16(
    pretrained=True,
)
model

KeyboardInterrupt: 

In [None]:
# Freezing the pretrained parameters is left disabled, so every layer of the
# network receives gradient updates.
#for param in model.parameters():
    #param.requires_grad = False 


# Replacement classifier head for the VGG-16 created above:
# 25088 -> 1000 -> 100 -> 2, with ReLU activations and one Dropout layer.
head_layers = [
    torch.nn.Linear(25088, 1000),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.4),
    torch.nn.Linear(1000, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 2),
]
fc = torch.nn.Sequential(*head_layers)
model.classifier = fc

# SGD with momentum over all model parameters; cross-entropy as the loss.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
)
criterion = torch.nn.CrossEntropyLoss()

# Move the model to the selected device (also displays it in the notebook).
model.to(device)

In [None]:
# Labelled data, viewed twice through the same listing file:
#   * train_data applies the random training augmentation,
#   * val_data applies the deterministic inference transform, so validation
#     accuracy is not perturbed by random crops / rotations and matches how
#     the test images are preprocessed.
train_data = MyDataset(txt='../input/mydata/train1.txt', type="train", transform=transform_train)
val_data = MyDataset(txt='../input/mydata/train1.txt', type="train", transform=transform_test)

# Test set (unlabelled); order is preserved so predictions line up with the listing.
test_data = MyDataset(txt='../input/mydata/test1.txt', type="test", transform=transform_test)
test_loader = DataLoader(test_data, batch_size=100, shuffle=False)

# Split the labelled data into a validation part (0.2) and a training part (0.8).
batch_size = 60
validation_split = 0.2
shuffle_dataset = True
random_seed = 42
dataset_size = len(train_data)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))  # np.floor rounds down
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)  # shuffle the sample order before splitting
train_indices, val_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

# Both datasets are built from the same listing, so the index lists computed
# above address the same samples in either of them.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size,
                                                sampler=valid_sampler)

In [None]:
def adjust_learning_rate(optimizer, epoch, lr):
    """Write the scheduled learning rate for ``epoch`` into ``optimizer``.

    Schedule:
        * ``epoch < 4``       -> ``0.01``
        * ``4 <= epoch < 6``  -> ``0.005``
        * ``epoch >= 6``      -> ``lr * 0.8`` (decay of the rate passed in)

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts with an
            ``'lr'`` entry.
        epoch: Index of the epoch the schedule is evaluated for.
        lr: Current learning rate; only used by the decay branch.

    Returns:
        The learning rate stored in every parameter group.
    """
    if epoch < 4:
        new_lr = 0.01
    elif epoch < 6:
        new_lr = 0.005
    else:
        # The last branch used to test ``epoch > 6``, which left epoch 6
        # outside every branch so the decay started one epoch late.
        new_lr = lr * 0.8
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr
def train(epochs):
    """Train the global ``model`` for ``epochs`` passes over ``train_loader``.

    Uses the module-level ``optimizer``, ``criterion`` and ``device``.  The
    loss of every 30th batch (starting with the first) is printed together
    with the index of the current epoch.

    Args:
        epochs: Number of full passes over ``train_loader`` to run.

    Returns:
        list[float]: Mean batch loss of each epoch, in order.  An epoch in
        which the loader yields no batch contributes no entry.
    """
    # Make sure layers such as Dropout are in training mode, whatever mode an
    # earlier evaluation left the model in.
    model.train()
    train_loss = []
    for e in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches += 1
            if batch_idx % 30 == 0:
                # Report the epoch currently running (the total number of
                # requested epochs was printed here before).
                print('epoch:', e)
                print("loss:", batch_loss)
            loss.backward()
            optimizer.step()
        if num_batches:
            train_loss.append(running_loss / num_batches)
    #plt.plot(train_loss,label="Training Loss")
    #plt.show()
    return train_loss
def  test_on_traindata():
    with torch.no_grad():
        correct = 0
        total = 0
        for id,data in enumerate(validation_loader):
            input,label = data
            input,label = input.to(device),label.to(device)
            output = model(input)
            _,predict = torch.max(output,dim=1)
            #print(predict.shape)
            #print(label.shape)
            total += label.size(0)
            correct += (predict==label).sum().item()
    
    print('Accuracy on test set:%d %%' % (100*correct/total)) 

def  test():
    with torch.no_grad():
        pre = np.array([],dtype = np.int64)
        pre = torch.from_numpy(pre)
        pre = pre.to(device)
        for id,input in enumerate(test_loader):
            input = input.to(device)
            output = model(input)
            _,predict = torch.max(output,dim=1)
            pre = torch.cat((pre,predict),0)
            #print(pre.data)  
        #写入csv
        pre = pre.cpu()
        pre = pre.numpy()
        res = pd.DataFrame({"label":pre})
        print('res=',res)
        return res
# Number of train / validate rounds.  The loop previously hard-coded 9 while
# this variable (then set to 5) was never read; the executed count is kept.
epochs = 9
model.train()
for i in range(epochs):
    print('epoch: ', i)
    # Exactly one pass over the training data per round.  Passing the round
    # index as the epoch count made round 0 train nothing and round i train
    # i epochs (36 passes in total instead of 9).
    train(1)
    test_on_traindata()
    adjust_learning_rate(optimizer, i, optimizer.param_groups[0]['lr'])
    # Read the rate back after the update so the log shows the value now in
    # effect rather than the one from before the adjustment.
    lr = optimizer.param_groups[0]['lr']
    print('lr = ', lr)
# Predict on the test set and write the submission file.
res = test()
res.to_csv("./submit3.csv")