In [11]:
!pip install torch torchvision timm numpy pyyaml pillow kagglehub

Collecting kagglehub
  Downloading kagglehub-0.3.12-py3-none-any.whl.metadata (38 kB)
Downloading kagglehub-0.3.12-py3-none-any.whl (67 kB)
Installing collected packages: kagglehub
Successfully installed kagglehub-0.3.12


In [38]:
import os
import zipfile
import requests
from tqdm import tqdm

def download_with_progress(url, save_path):
    # Send GET request
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Check for HTTP errors
    
    # Get total file size in bytes
    total_size = int(response.headers.get('content-length', 0))
    
    # Download with progress bar
    with open(save_path, 'wb') as f, tqdm(
        desc=os.path.basename(save_path),
        total=total_size,
        unit='iB',
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in response.iter_content(chunk_size=1024):
            size = f.write(data)
            bar.update(size)

# Create data directory if needed
data_dir = 'data'
os.makedirs(data_dir, exist_ok=True)

# Download parameters
url = 'https://www.kaggle.com/api/v1/datasets/download/luckyhathaway/tiny-imagenet-c'
zip_path = os.path.join(data_dir, 'tiny-imagenet-c.zip')
extract_path = os.path.join(data_dir)

print("Downloading dataset...")
download_with_progress(url, zip_path)

# Extract with progress
print("\nExtracting dataset...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    file_list = zip_ref.namelist()
    for file in tqdm(file_list, desc="Extracting"):
        zip_ref.extract(file, extract_path)

# Clean up
os.remove(zip_path)

print(f"\nDataset successfully downloaded and extracted to {extract_path}")

Downloading dataset...


tiny-imagenet-c.zip: 100%|██████████████████| 1.25G/1.25G [03:04<00:00, 7.30MiB/s]



Extracting dataset...


Extracting: 100%|███████████████████████| 750000/750000 [01:21<00:00, 9213.34it/s]



Dataset successfully downloaded and extracted to data


In [3]:
import torch
from models.coca import COCA
from data.imagenet_c import ImageNetC
from scripts.test_accuracy import test_accuracy
from models.resnet import resnet50
from models.vit import vit_base_patch16_224
from utils.augmentations import get_transform

In [5]:
# 参数设置
args = {
    'data_root': './data/Tiny-ImageNet-C',
    'batch_size': 32,
    'workers': 4,
    'corruption': 'gaussian_noise',
    'severity': 5,
    'lr_anchor': 0.001,
    'lr_aux': 0.00025,
    'momentum': 0.9
}
print(args)

{'data_root': './data/Tiny-ImageNet-C', 'batch_size': 32, 'workers': 4, 'corruption': 'gaussian_noise', 'severity': 5, 'lr_anchor': 0.001, 'lr_aux': 0.00025, 'momentum': 0.9}


In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
anchor_model = vit_base_patch16_224(pretrained=True).to(device)
aux_model = resnet50(pretrained=True).to(device)
coca = COCA(anchor_model, aux_model, lr_anchor=args['lr_anchor'], lr_aux=args['lr_aux'], momentum=args['momentum'])

transform_anchor = get_transform('vit_base_patch16_224')
transform_aux = get_transform('resnet50')
dataset = ImageNetC(
    root=args['data_root'],
    corruption_type=args['corruption'],
    severity=args['severity'],
    transform_anchor=transform_anchor,
    transform_aux=transform_aux
)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=args['batch_size'], num_workers=args['workers'], shuffle=True)

for i, (images_anchor, images_aux, _) in enumerate(data_loader):
    images_anchor = images_anchor.to(device)
    images_aux = images_aux.to(device)
    coca.update(images_anchor, images_aux)
    if (i+1) % 10 == 0:
        print(f'Adapted on batch {i+1}/{len(data_loader)}')

accuracy = test_accuracy(coca, args['data_root'], args['batch_size'], args['workers'], args['corruption'], args['severity'])
print(f'Accuracy on {args['corruption']} (severity {args['severity']}): {accuracy:.2f}%')

Added blocks.0.norm1.weight to optimizer
Added blocks.0.norm1.bias to optimizer
Added blocks.0.norm2.weight to optimizer
Added blocks.0.norm2.bias to optimizer
Added blocks.1.norm1.weight to optimizer
Added blocks.1.norm1.bias to optimizer
Added blocks.1.norm2.weight to optimizer
Added blocks.1.norm2.bias to optimizer
Added blocks.2.norm1.weight to optimizer
Added blocks.2.norm1.bias to optimizer
Added blocks.2.norm2.weight to optimizer
Added blocks.2.norm2.bias to optimizer
Added blocks.3.norm1.weight to optimizer
Added blocks.3.norm1.bias to optimizer
Added blocks.3.norm2.weight to optimizer
Added blocks.3.norm2.bias to optimizer
Added blocks.4.norm1.weight to optimizer
Added blocks.4.norm1.bias to optimizer
Added blocks.4.norm2.weight to optimizer
Added blocks.4.norm2.bias to optimizer
Added blocks.5.norm1.weight to optimizer
Added blocks.5.norm1.bias to optimizer
Added blocks.5.norm2.weight to optimizer
Added blocks.5.norm2.bias to optimizer
Added blocks.6.norm1.weight to optimizer

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x13adcbce0>
Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/opt/anaconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1582, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/anaconda3/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.12/multiprocessing/connection.py", line 1136, in wait
    ready = selector.select(timeout)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.12/selectors.py", line 415, in select
    fd_e