In [3]:
!python3 -m pip install torch==1.9.0 torchvision==0.10.0

Defaulting to user installation because normal site-packages is not writeable
Collecting torch==1.9.0
  Downloading torch-1.9.0-cp39-cp39-manylinux1_x86_64.whl.metadata (25 kB)
Collecting torchvision==0.10.0
  Downloading torchvision-0.10.0-cp39-cp39-manylinux1_x86_64.whl.metadata (7.9 kB)
Downloading torch-1.9.0-cp39-cp39-manylinux1_x86_64.whl (831.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m831.4/831.4 MB[0m [31m636.0 kB/s[0m eta [36m0:00:00[0m00:01[0m0:03[0mm
[?25hDownloading torchvision-0.10.0-cp39-cp39-manylinux1_x86_64.whl (22.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.1/22.1 MB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 2.4.0
    Uninstalling torch-2.4.0:
      Successfully uninstalled torch-2.4.0
  You can safely remove it manually.[0m[33m
[0m  Attempting uninstall: torchvisi

In [1]:
!which pip

/mnt/net/a1x256-ai01/hotel/mattlee/myproject/myenv/bin/pip


In [8]:
import os
import json
from PIL import Image
import torch
import torchvision
from torchvision import transforms
from pytorch_pretrained_vit import ViT
import sqlite3


In [3]:
conn = sqlite3.connect('image_classifications.db')
cursor = conn.cursor()
cursor.execute('''
    CREATE TABLE IF NOT EXISTS predictions (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        class_name TEXT,
        image_path TEXT,
        top1 TEXT, top1_conf REAL,
        top2 TEXT, top2_conf REAL,
        top3 TEXT, top3_conf REAL,
        top4 TEXT, top4_conf REAL,
        top5 TEXT, top5_conf REAL
    )
''')
conn.commit()


In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used.")
    print(f"GPU model: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU instead.")
 
x = torch.tensor([1.0, 2.0, 3.0])
x = x.to(device)
print(f"Tensor: {x}")

GPU is available and will be used.
GPU model: Quadro RTX 8000
Tensor: tensor([1., 2., 3.], device='cuda:0')


In [9]:

model_name = 'B_16_imagenet1k'
model = ViT(model_name, pretrained=True)
model.eval()
    
index_tosynset_label = json.load(open('ImageNet_class_index.json'))
index_to_classname = json.load(open('imagenet-simple-labels.json'))


tfms = transforms.Compose([
    transforms.Resize(model.image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5]),
])

data_dir = 'final_val/'
correct, total = 0,0

for synset_folder in sorted(os.listdir(data_dir)):
    
    synset_path = os.path.join(data_dir, synset_folder)
    if not os.path.isdir(synset_path):
        continue
    class_name = index_to_classname[total]
   
    print(class_name + "," + synset_folder)
    img_name = next((f for f in os.listdir(synset_path) if os.path.isfile(os.path.join(synset_path, f))), None)
    
    if img_name is None:
        continue
    img_path = os.path.join(synset_path, img_name)
    img = Image.open(img_path).convert('RGB')
    img = tfms(img).unsqueeze(0) 
 
    with torch.no_grad():
        outputs = model(img).squeeze(0)

        
      
        
    predicted_idx = torch.argmax(outputs).item()
    predicted_idx_str = str(predicted_idx)
        # Map to synset and human-readable label
    predicted_synset, predicted_label = index_tosynset_label[predicted_idx_str]

        # Compare the predicted synset with the folder's synset
    if predicted_synset == synset_folder:
        correct += 1
        
    total += 1
    
    top5_indices = torch.topk(outputs, k=5).indices.tolist()
    top5_probs = torch.softmax(outputs, dim=0)[top5_indices].tolist()

    top5_labels = [index_tosynset_label[str(idx)][1] for idx in top5_indices]
    
    cursor.execute('''INSERT INTO predictions (class_name, image_path, top1, top1_conf,
                                            top2, top2_conf, top3, top3_conf,
                                            top4, top4_conf, top5, top5_conf)
                             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                        (class_name, img_path, top5_labels[0], top5_probs[0],
                        top5_labels[1], top5_probs[1],
                        top5_labels[2], top5_probs[2],
                        top5_labels[3], top5_probs[3],
                        top5_labels[4], top5_probs[4]))
    conn.commit()

        
        
accuracy = correct / total * 100
print(f'Accuracy: {accuracy:.2f}%')
conn.close()
        
        
            
            

Loaded pretrained weights.
tench,n01440764


ProgrammingError: Cannot operate on a closed database.

In [10]:
def fetch_sample(offset):
    # connect to the sqlite3 database
    conn = sqlite3.connect('image_classifications.db')
    c = conn.cursor()
 
    c.execute(f'SELECT * FROM predictions LIMIT 1 OFFSET {offset}')
    record = c.fetchone()
 
    # check if a record was found
    if record:
        print("Sample Record from Database:")
        print(f"ID: {record[0]}")
        print(f"Class Name: {record[1]}")
        print(f"Image Path: {record[2]}")
        print(f"Top 1 Prediction: {record[3]} with confidence {record[4]:.2f}")
        print(f"Top 2 Prediction: {record[5]} with confidence {record[6]:.2f}")
        print(f"Top 3 Prediction: {record[7]} with confidence {record[8]:.2f}")
        print(f"Top 4 Prediction: {record[9]} with confidence {record[10]:.2f}")
        print(f"Top 5 Prediction: {record[11]} with confidence {record[12]:.2f}")
    else:
        print("No more records found.")
    conn.close()


fetch_sample(0)
fetch_sample(1)
fetch_sample(2)

Sample Record from Database:
ID: 1
Class Name: tench
Image Path: final_val/n01440764/ILSVRC2012_val_00031094.JPEG
Top 1 Prediction: tench with confidence 0.99
Top 2 Prediction: barracouta with confidence 0.01
Top 3 Prediction: reel with confidence 0.00
Top 4 Prediction: pole with confidence 0.00
Top 5 Prediction: coho with confidence 0.00
Sample Record from Database:
ID: 2
Class Name: goldfish
Image Path: final_val/n01443537/ILSVRC2012_val_00028713.JPEG
Top 1 Prediction: goldfish with confidence 1.00
Top 2 Prediction: coral_reef with confidence 0.00
Top 3 Prediction: rock_beauty with confidence 0.00
Top 4 Prediction: tench with confidence 0.00
Top 5 Prediction: lionfish with confidence 0.00
Sample Record from Database:
ID: 3
Class Name: great white shark
Image Path: final_val/n01484850/ILSVRC2012_val_00017194.JPEG
Top 1 Prediction: great_white_shark with confidence 0.97
Top 2 Prediction: tiger_shark with confidence 0.02
Top 3 Prediction: hammerhead with confidence 0.00
Top 4 Prediction