In [None]:
!wget https://raw.githubusercontent.com/christophschuhmann/improved-aesthetic-predictor/main/sac+logos+ava1-l14-linearMSE.pth

In [None]:
!pip install "pytorch_lightning==1.8.6"

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn


class MLP(pl.LightningModule):
    """Stack of linear layers mapping an embedding vector to a single score.

    The activation functions are commented out, so the stack consists only of
    ``nn.Linear`` and ``nn.Dropout`` modules. Dropout is the identity in eval
    mode, which makes the network an affine map at inference time.

    Args:
        input_size: Dimensionality of the input embedding.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size

        # The order of the modules fixes their indices inside nn.Sequential,
        # and the state-dict keys are derived from those indices.
        stack = [
            nn.Linear(self.input_size, 1024),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            # nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            # nn.ReLU(),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the result."""
        scores = self.layers(x)
        return scores


device = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP embedding dim is 768 for CLIP ViT L 14
model = MLP(768)

# load the model you trained previously or the model available in this repo
state_dict = torch.load(
    "sac+logos+ava1-l14-linearMSE.pth",
    map_location=device,
)
model.load_state_dict(state_dict)

# Move the module to the selected device and switch it to eval mode.
# Both calls return the module itself, so its repr is the cell output.
model.to(device).eval()

### Bias: `linear(0)`

In [None]:
# All-zeros input vector. It is created on the same device as the model so
# that `model(zero)` does not fail with a device mismatch when CUDA is used.
zero = torch.zeros(model.input_size, device=device)

In [None]:
# Evaluate the model on the zero vector without tracking gradients; the
# result is kept as `bias` for the cells below.
with torch.no_grad():
    bias = model(zero)

print("Aesthetic score predicted by the model:", bias, sep="\n")

### Weights: `linear(e[i]) - linear(0)` for all `i`

In [None]:
# Identity matrix: row i is the one-hot basis vector e[i]. It is created on
# the model's device so that `model(one_hots)` works when CUDA is used.
one_hots = torch.eye(model.input_size, device=device)

In [None]:
# Feed every one-hot row through the model in one batch and subtract the
# zero-input output, leaving one value per input dimension.
with torch.no_grad():
    weights = model(one_hots).sub(bias)

### Alternate model

In [None]:
def alternate(x):
    """Apply the extracted affine map: ``x @ weights + bias``.

    Uses the notebook-level ``weights`` and ``bias`` tensors.
    """
    projection = x @ weights
    return projection + bias

### Comparison of Outputs

In [None]:
N_tests = 100
random_inputs = [random_valid_input() for _ in range(N_tests)]

In [None]:
random_valid_input = lambda: torch.randn_like(zero)

In [None]:
@torch.no_grad()
def run_model(x):
    """Evaluate the notebook-level ``model`` on ``x`` with gradients disabled."""
    return model(x)
        
def run_alternate(x):
    """Evaluate the extracted affine map ``alternate`` on ``x``."""
    return alternate(x)

#### Visual

In [None]:
%%time
model_outs = [run_model(x) for x in random_inputs]

In [None]:
%%time
alternate_outs = [run_alternate(x) for x in random_inputs]

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Matplotlib cannot consume CUDA tensors, so collect the per-input outputs
# into flat NumPy arrays on the CPU before plotting.
model_scores = torch.stack(model_outs).flatten().detach().cpu().numpy()
alternate_scores = torch.stack(alternate_outs).flatten().detach().cpu().numpy()

# One marker per random input; if both functions agree, the points lie on
# the diagonal y = x.
fig, ax = plt.subplots()
ax.plot(model_scores, alternate_scores, ".")
ax.set(
    xlabel="model output",
    ylabel="alternate output",
    title="Model vs. alternate outputs on random inputs",
);

In [None]:
def assert_equivalence(embed, run_model, run_alternate, rtol=1e-3, atol=1e-8):
    """Assert that two callables give numerically close outputs for ``embed``.

    Args:
        embed: Input passed unchanged to both callables.
        run_model: Reference callable; its output is the first argument to
            ``torch.allclose``.
        run_alternate: Callable under test; its output is the second argument
            to ``torch.allclose``.
        rtol: Relative tolerance forwarded to ``torch.allclose``.
        atol: Absolute tolerance forwarded to ``torch.allclose`` (the default
            matches ``torch.allclose``'s own default).

    Raises:
        AssertionError: If the outputs are not close. The message reports the
            summed squared difference between them.
    """
    expected = run_model(embed)
    actual = run_alternate(embed)
    if not torch.allclose(expected, actual, rtol=rtol, atol=atol):
        # Raised explicitly (not via `assert`) so the check is never stripped
        # when Python runs with optimizations enabled.
        raise AssertionError(
            f"total error of {torch.sum(torch.square(expected - actual))}"
        )

In [None]:
# easy true negatives: none of these comparisons should raise
assert_equivalence(embed=zero, run_model=run_model, run_alternate=run_model)
assert_equivalence(embed=zero, run_model=run_model, run_alternate=run_alternate)
assert_equivalence(embed=random_valid_input(), run_model=run_model, run_alternate=run_model)

In [None]:
# easy true positive: an output shifted by +1 must be reported as different
try:
    assert_equivalence(zero, run_model, lambda x: run_model(x) + 1)
except AssertionError:
    pass  # expected: the deliberate mismatch was detected
else:
    # Raised in the `else` branch so the `except AssertionError` above cannot
    # swallow it; inside the `try` body this failure would be silently ignored.
    raise AssertionError("assertion should've failed but didnt")

In [None]:
# real true negatives: the model and the alternate must agree on every
# random input generated earlier

for random_input in random_inputs:
    assert_equivalence(
        embed=random_input,
        run_model=run_model,
        run_alternate=run_alternate,
    )

### Bonus

In [None]:
# Flatten to a single 1-D NumPy array on the CPU: Matplotlib then sees one
# dataset regardless of how it would unpack a 2-D tensor, and CUDA tensors
# (which it cannot convert) are handled too.
weight_values = weights.detach().flatten().cpu().numpy()

fig, ax = plt.subplots()
ax.hist(weight_values)
ax.set(
    xlabel="weight value",
    ylabel="count",
    title="Distribution of the extracted weights",
);