In [1]:
import numpy as np
from collections import defaultdict
import copy
np.random.seed=2021

In [2]:
def temp_scaled_softmax(data,temp=1.0):
    prob=np.exp(data/temp)/np.sum(np.exp(data/temp),axis=-1)
    return prob
def greedy_search(prob):
    return np.argmax(prob)
    
def topk_sampling(prob,k=10,verbose=True):
    # this function is only used to process one single example
    if prob.ndim>1:
        prob=prob[0]
    topk_label=np.argsort(prob)[-k:]
    topk_prob=prob[topk_label]/np.sum(prob[topk_label])
    label=np.random.choice(topk_label,p=topk_prob)
    if verbose:
        print('orig_all_prob:{}'.format(prob))
        print('**********************')
        print('topk_sorted_label:{}'.format(topk_label))
        print('topk_sorted_prob:{}'.format(prob[topk_label]))
        print('**********************')
        print('topk_new_prob:{}'.format(topk_prob))
        print('finally sampled label:{}'.format(label))
    return label
def topp_sampling(prob,p=0.9,verbose=True):
    # this function is only used to process one single example
    if prob.ndim>1:
        prob=prob[0]
    sorted_htol_label=np.argsort(prob)[::-1]
    sorted_htol_prob=prob[sorted_htol_label]
    for i in range(prob.shape[0]):
        if np.sum(sorted_htol_prob[:i+1])>=p:
            break
    topp_htol_label=sorted_htol_label[:i+1]
    topp_htol_prob=sorted_htol_prob[topp_htol_label]/np.sum(sorted_htol_prob[topp_htol_label])
    label=np.random.choice(topp_htol_label,p=topp_htol_prob)
    if verbose:
        print('orig_all_prob:{}'.format(prob))
        print('**********************')
        print('topp_sorted_label:{}'.format(topp_htol_label))
        print('topp_sorted_prob:{}'.format(sorted_htol_prob[topp_htol_label]))
        print('**********************')
        print('topp_new_prob:{}'.format(topp_htol_prob))
        print('finally sampled label:{}'.format(label))

In [3]:
test_data=np.array([[3,2,4,1]]).astype(float)
print(temp_scaled_softmax(test_data,temp=1.0))
print(temp_scaled_softmax(test_data,temp=0.8))
print(temp_scaled_softmax(test_data,temp=0.6))
print(temp_scaled_softmax(test_data,temp=0.4))
print(temp_scaled_softmax(test_data,temp=0.2))
print(temp_scaled_softmax(test_data,temp=0.1))
print(temp_scaled_softmax(test_data,temp=0.01))

[[0.23688282 0.08714432 0.64391426 0.0320586 ]]
[[0.20580651 0.05896455 0.71833531 0.01689363]]
[[0.15339683 0.02897292 0.81215798 0.00547228]]
[[7.53504725e-02 6.18514343e-03 9.17956677e-01 5.07707490e-04]]
[[6.69254708e-03 4.50940275e-05 9.93262055e-01 3.03841168e-07]]
[[4.53978686e-05 2.06106005e-09 9.99954600e-01 9.35719813e-14]]
[[3.72007598e-044 1.38389653e-087 1.00000000e+000 5.14820022e-131]]


In [4]:
test_data=np.array([[3,2,4,1]]).astype(float)
prob=temp_scaled_softmax(test_data)
topk_label=topk_sampling(prob,k=3)
print('**********************')
topp_label=topp_sampling(prob,p=0.9)

orig_all_prob:[0.23688282 0.08714432 0.64391426 0.0320586 ]
**********************
topk_sorted_label:[1 0 2]
topk_sorted_prob:[0.08714432 0.23688282 0.64391426]
**********************
topk_new_prob:[0.09003057 0.24472847 0.66524096]
finally sampled label:2
**********************
orig_all_prob:[0.23688282 0.08714432 0.64391426 0.0320586 ]
**********************
topp_sorted_label:[2 0 1]
topp_sorted_prob:[0.08714432 0.64391426 0.23688282]
**********************
topp_new_prob:[0.09003057 0.66524096 0.24472847]
finally sampled label:1


In [5]:
class beam_structure(object):
    def __init__(self,decode_start_input=0):
        self.storage=[(decode_start_input,1.0)]
    def add_item(self,label,prob):
        self.storage.append((label,prob))
    def get_total_prob(self,):
        labels,probs=zip(*self.storage)
        return np.exp(np.sum(np.log(probs)))
    def get_all_labels(self,):
        labels,probs=zip(*self.storage)
        return labels
    def get_all_probs(self,):
        labels,probs=zip(*self.storage)
        return probs
class assume_model(object):
    def __init__(self,label_dim=5):
        self.label_dim=label_dim
    def __call__(self,inputs):
        return np.random.randn(self.label_dim)
def beam_search(model,beam_num=3,max_num=10,softmax_temp=1.0):
    index=0
    total_beams=[beam_structure() for _ in range(beam_num)]
    while index<max_num:
        all_current_beams=[]
        for i in range(beam_num):
            inputs=total_beams[i].get_all_labels()
            outputs=model(inputs)
            prob=temp_scaled_softmax(outputs,softmax_temp)
            topk_labels=np.argsort(prob)[-beam_num:]
            topk_probs=prob[topk_labels]
            for label,prob in zip(topk_labels,topk_probs):
                new_beam=copy.deepcopy(total_beams[0])
                new_beam.add_item(label,prob)
                all_current_beams.append(new_beam)
            if index==0:
                break
        exist_data=defaultdict(list)
        filtered_beams=[]
        for current_beam in all_current_beams:
            label,prob=current_beam.storage[-1]
            if label not in exist_data:
                exist_data[label].append(current_beam)
        all_current_beams=[]
        for label,current_beams in exist_data.items():
            all_current_beams.append(sorted(current_beams,key=lambda x:x.storage[-1][-1],reverse=True)[0])         
        total_beams=sorted(all_current_beams,key=lambda x:x.get_total_prob(),reverse=True)[:beam_num]
        for i in range(beam_num):
            print('step:{},label_prob:{},accumulate_prob:{}'.format(index+1,total_beams[i].storage,total_beams[i].get_total_prob()))
        index+=1
    optimized_beam=sorted(total_beams,key=lambda x:x.get_total_prob(),reverse=True)[0]
    return optimized_beam.get_all_labels()   

In [6]:
model=assume_model(label_dim=5)
beam_search(model,beam_num=3,max_num=5,softmax_temp=0.7)

step:1,label_prob:[(0, 1.0), (4, 0.8689966256413102)],accumulate_prob:0.8689966256413102
step:1,label_prob:[(0, 1.0), (2, 0.05894956921520326)],accumulate_prob:0.05894956921520325
step:1,label_prob:[(0, 1.0), (0, 0.04397906199868869)],accumulate_prob:0.043979061998688694
step:2,label_prob:[(0, 1.0), (4, 0.8689966256413102), (4, 0.555428782677956)],accumulate_prob:0.48266573793120443
step:2,label_prob:[(0, 1.0), (4, 0.8689966256413102), (0, 0.4347996574747333)],accumulate_prob:0.37783943517554075
step:2,label_prob:[(0, 1.0), (4, 0.8689966256413102), (3, 0.1929366105828902)],accumulate_prob:0.16766126355920305
step:3,label_prob:[(0, 1.0), (4, 0.8689966256413102), (4, 0.555428782677956), (2, 0.6982224192211296)],accumulate_prob:0.3370080392134773
step:3,label_prob:[(0, 1.0), (4, 0.8689966256413102), (4, 0.555428782677956), (1, 0.4004969512456286)],accumulate_prob:0.19330615651216893
step:3,label_prob:[(0, 1.0), (4, 0.8689966256413102), (4, 0.555428782677956), (0, 0.2853610473005622)],accu

(0, 4, 4, 2, 4, 0)