In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import timm
from PIL import Image
import os

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.cuda.is_available())

In [8]:
!pip install py7zr
import py7zr
with py7zr.SevenZipFile('./input/cifar-10/test.7z', mode='r') as z:
    z.extractall('./working')



In [9]:
transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

from torch.utils.data import Subset
num_samples = 30000  # 取り出したいサンプル数
indices = torch.randperm(len(trainset))[:num_samples]  # ランダムにインデックスを選択
trainset = Subset(trainset, indices)  # サブセットを作成

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=4)
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.head = nn.Linear(model.head.in_features, 10)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model = model.to(device)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29078327.17it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [10]:
from tqdm import tqdm

for epoch in range(5):  # loop over the dataset multiple times
    progress_bar = tqdm(enumerate(trainloader, 0), total=len(trainloader), desc="Epoch: %d" % epoch)
    for i, data in progress_bar:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Update the progress bar
        progress_bar.set_postfix({'loss': loss.item()})

Epoch: 0: 100%|██████████| 469/469 [09:05<00:00,  1.16s/it, loss=0.0893] 
Epoch: 1: 100%|██████████| 469/469 [09:05<00:00,  1.16s/it, loss=0.0571]  
Epoch: 2: 100%|██████████| 469/469 [09:05<00:00,  1.16s/it, loss=0.0961]  
Epoch: 3: 100%|██████████| 469/469 [09:05<00:00,  1.16s/it, loss=0.00191] 
Epoch: 4: 100%|██████████| 469/469 [09:06<00:00,  1.16s/it, loss=0.000246]


In [None]:
def num2label(fileNum):
    filename = "{}.png".format(fileNum)
    image = Image.open(os.path.join("test", filename))
    data = transform(image)
    with torch.no_grad():
        data = data.unsqueeze(0).to(device)
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
#         print(predicted)
        if predicted==0:
            return "airplane"
        elif predicted==1:
            return "automobile"
        elif predicted ==2:
            return "bird"
        elif predicted ==3:
            return "cat"
        elif predicted ==4:
            return "deer"
        elif predicted ==5:
            return "dog"
        elif predicted ==6:
            return "frog"
        elif predicted ==7:
            return "horse"
        elif predicted ==8:
            return "ship"
        elif predicted ==9:
            return "truck"
        else:
            print("error")
        

In [None]:
import csv

# Define the data
header = ['id', 'label']

# Write the data to a CSV file
with open('myfile.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(header)
    for i in range(1,300000):
        ans = num2label(i)
        print(i,ans)
        writer.writerow([i,ans])

1 bird
2 airplane
3 automobile
4 ship
5 airplane
6 cat
7 airplane
8 cat
9 ship
10 cat
11 bird
12 horse
13 horse
14 cat
15 dog
16 airplane
17 dog
18 bird
19 ship
20 deer
21 ship
22 ship
23 ship
24 automobile
25 frog
26 truck
27 automobile
28 cat
29 bird
30 cat
31 cat
32 horse
33 dog
34 dog
35 deer
36 deer
37 cat
38 horse
39 ship
40 cat
41 bird
42 bird
43 truck
44 deer
45 horse
46 automobile
47 frog
48 bird
49 airplane
50 truck
51 bird
52 horse
53 dog
54 frog
55 ship
56 ship
57 frog
58 ship
59 airplane
60 cat
61 dog
62 cat
63 deer
64 frog
65 ship
66 truck
67 ship
68 airplane
69 cat
70 frog
71 truck
72 bird
73 truck
74 frog
75 frog
76 horse
77 airplane
78 frog
79 horse
80 cat
81 cat
82 deer
83 ship
84 cat
85 deer
86 frog
87 dog
88 bird
89 frog
90 truck
91 automobile
92 airplane
93 cat
94 horse
95 automobile
96 frog
97 deer
98 bird
99 truck
100 bird
101 deer
102 airplane
103 cat
104 cat
105 ship
106 airplane
107 automobile
108 bird
109 frog
110 frog
111 horse
112 bird
113 deer
114 airplane